In [4]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import open_clip
from tqdm import tqdm
import matplotlib.pyplot as plt
import textwrap
import re

# ------------------------------
# DEBUGGING & CONFIGURATION
# ------------------------------
# This forces CUDA operations to be synchronous for accurate error reporting.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

CONFIG = {
    "model_name": "ViT-B-16",
    "pretrained": "laion2b_s34b_b88k",
    "csv_train": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.csv",
    "csv_val": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.csv",
    "csv_test": "C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/test.csv",
    "save_path": "finetuned_expression_only_best.pt", # <-- New save path
    "plot_save_dir": "test_results_expression_only", # <-- New plot directory
    "num_plots": 100,
    "batch_size": 32, # Start small for large models
    "epochs": 100,
    "lr": 1e-6,
    "patience": 10,
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # --- FOCUSED: Prompts for Expression ONLY ---
    "expression_prompts": [
        "a photo of a person with an anger expression.", "a photo of a person with a contempt expression.",
        "a photo of a person with a disgust expression.", "a photo of a person with a happiness expression.",
        "a photo of a person with a fear expression.", "a photo of a person with a sadness expression.",
        "a photo of a person with a surprise expression.", "a photo of a person with a neutral expression."
    ]
}

# --- Create simplified keys from prompts for mapping ---
CONFIG["expression_keys"] = [p.split(" with ")[-1].replace("an ", "").replace("a ", "").replace(" expression.", "") for p in CONFIG["expression_prompts"]]
print(CONFIG["expression_keys"])
# ------------------------------
# Helper Function for Parsing
# ------------------------------
def parse_expression_from_prompt(prompt_text):
    """
    A robust parser to find any of the expression keys in the prompt.
    """
    # Search for any of the expression keys from the list in the prompt text
    for expression_key in CONFIG["expression_keys"]:
        # Use regex with word boundaries (\b) to ensure it matches the whole word/phrase
        if re.search(r'\b' + re.escape(expression_key) + r'\b', prompt_text, re.IGNORECASE):
            return expression_key
    return None

# ------------------------------
# Preprocessing and Caching
# ------------------------------
def preprocess_and_cache_csv(csv_path):
    """
    Parses prompts for expression and saves a cached version.
    """
    cache_path = csv_path.replace(".csv", ".expression_cached.csv")
    if os.path.exists(cache_path):
        print(f"Loading preprocessed expression data from cache: {cache_path}")
        try:
            df_cache = pd.read_csv(cache_path)
            if df_cache.empty:
                 print("Warning: Cached file is empty. Reprocessing.")
            else:
                return df_cache
        except pd.errors.EmptyDataError:
            print(f"Warning: Cached file is empty. Reprocessing.")

    print(f"Preprocessing expression data and caching: {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        print(f"Error or empty file at {csv_path}")
        return None

    new_data = []
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Preprocessing {os.path.basename(csv_path)}"):
        expression = parse_expression_from_prompt(row['prompt'])
        if expression:
            row_data = {
                'filepath': row['filepath'],
                'prompt': row['prompt'],
                'expression_idx': CONFIG["expression_keys"].index(expression)
            }
            new_data.append(row_data)

    if not new_data:
        print("Warning: Preprocessing resulted in an empty dataset for expression.")
        return None

    cached_df = pd.DataFrame(new_data)
    cached_df.to_csv(cache_path, index=False)
    return cached_df

# ------------------------------
# Dataset Class
# ------------------------------
class ExpressionDataset(Dataset):
    def __init__(self, df, preprocess):
        self.df = df if df is not None else pd.DataFrame()
        self.preprocess = preprocess

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row["filepath"]
        try:
            image = Image.open(image_path).convert("RGB")
            image = self.preprocess(image)
        except FileNotFoundError:
            image = torch.zeros((3, 224, 224))

        expression_idx = torch.tensor(row["expression_idx"])
        return image, expression_idx, image_path

# ------------------------------
# Training and Evaluation Functions
# ------------------------------
def train_and_validate(model, train_loader, val_loader, expression_text_features):
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
    loss_fn = torch.nn.CrossEntropyLoss()
    best_val_loss = float("inf")
    patience_counter = 0
    model_saved = False

    for epoch in range(CONFIG["epochs"]):
        model.train()
        total_train_loss = 0
        for images, gt_indices, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Train]"):
            images, gt_indices = images.to(CONFIG["device"]), gt_indices.to(CONFIG["device"])

            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits_per_image = (100.0 * image_features @ expression_text_features.T)
            loss = loss_fn(logits_per_image, gt_indices)

            optimizer.zero_grad(); loss.backward(); optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for images, gt_indices, _ in tqdm(val_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']} [Val]"):
                images, gt_indices = images.to(CONFIG["device"]), gt_indices.to(CONFIG["device"])
                image_features = model.encode_image(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                logits_per_image = (100.0 * image_features @ expression_text_features.T)
                loss = loss_fn(logits_per_image, gt_indices)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"✅ Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), CONFIG["save_path"])
            print(f"🎉 Saved best model to {CONFIG['save_path']}")
            patience_counter = 0
            model_saved = True
        else:
            patience_counter += 1
            print(f"⚠️ No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
        if patience_counter >= CONFIG["patience"]:
            print("🛑 Early stopping triggered."); break

    return model_saved


def test_and_plot(model, expression_text_features, preprocess):
    print("\n--- Starting Final Testing and Plotting Phase for Expression ---")
    model.load_state_dict(torch.load(CONFIG["save_path"])); model.to(CONFIG["device"]).eval()
    print("Best model loaded.")

    test_df = preprocess_and_cache_csv(CONFIG["csv_test"])
    if test_df is None or test_df.empty:
        print(f"Error or empty data in test CSV. Halting testing.")
        return

    test_dataset = ExpressionDataset(test_df, preprocess)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    correct_predictions, total_samples = 0, 0
    all_results = []

    with torch.no_grad():
        for images, gt_indices, image_paths in tqdm(test_loader, desc="[Testing Expression]"):
            images = images.to(CONFIG["device"])
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            logits = (100.0 * image_features @ expression_text_features.T)
            preds = logits.argmax(dim=-1)

            for i in range(len(images)):
                is_correct = (preds[i] == gt_indices[i]).item()
                if is_correct: correct_predictions += 1

                all_results.append({
                    "image_path": image_paths[i],
                    "ground_truth": CONFIG["expression_keys"][gt_indices[i]],
                    "prediction": CONFIG["expression_keys"][preds[i]],
                    "is_correct": is_correct
                })
            total_samples += len(images)

    if total_samples > 0:
        accuracy = (correct_predictions / total_samples) * 100
        print(f"\n📊 Expression Test Accuracy: {accuracy:.2f}%")
    else:
        print("No valid samples were processed in the test set.")

    print(f"\n--- Plotting up to {CONFIG['num_plots']} expression results ---")
    os.makedirs(CONFIG["plot_save_dir"], exist_ok=True)

    for i, result in enumerate(all_results):
        if i >= CONFIG["num_plots"]: break
        try:
            img = Image.open(result["image_path"])
        except FileNotFoundError: continue

        fig, ax = plt.subplots(figsize=(8, 10))
        ax.imshow(img); ax.axis("off")

        title = f"Result {i+1}: {'CORRECT' if result['is_correct'] else 'INCORRECT'}"
        color = 'green' if result['is_correct'] else 'red'
        fig.suptitle(title, fontsize=16, color=color)

        text = (f"Ground Truth: {result['ground_truth']}\n"
                f"Prediction:   {result['prediction']}")

        ax.set_title(text, fontsize=12, pad=10)

        save_name = f"result_{i+1}_{'correct' if result['is_correct'] else 'incorrect'}.png"
        plt.savefig(os.path.join(CONFIG["plot_save_dir"], save_name), bbox_inches='tight')
        plt.close(fig)

    print("--- Plotting complete ---")

# ==============================
#      MAIN EXECUTION BLOCK
# ==============================
if __name__ == '__main__':
    print(f"Using device: {CONFIG['device']}")
    if torch.cuda.is_available():
        print(f"PyTorch Version: {torch.__version__}, CUDA Version: {torch.version.cuda}")

    model, _, preprocess = open_clip.create_model_and_transforms(
        CONFIG["model_name"], pretrained=CONFIG["pretrained"], device=CONFIG["device"]
    )
    tokenizer = open_clip.get_tokenizer(CONFIG["model_name"])

    with torch.no_grad():
        expression_text_tokens = tokenizer(CONFIG["expression_prompts"]).to(CONFIG["device"])
        expression_text_features = model.encode_text(expression_text_tokens)
        expression_text_features = expression_text_features / expression_text_features.norm(dim=-1, keepdim=True)

    train_df = preprocess_and_cache_csv(CONFIG["csv_train"])
    val_df = preprocess_and_cache_csv(CONFIG["csv_val"])

    if train_df is not None and val_df is not None and not train_df.empty and not val_df.empty:
        train_dataset = ExpressionDataset(train_df, preprocess)
        val_dataset = ExpressionDataset(val_df, preprocess)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

        training_successful = train_and_validate(model, train_loader, val_loader, expression_text_features)

        if training_successful:
            test_and_plot(model, expression_text_features, preprocess)
        else:
            print("\nSkipping testing phase: No model was saved during training.")
    else:
        print("\nSkipping training: Training/validation datasets are empty or could not be loaded.")


['anger', 'contempt', 'disgust', 'happiness', 'fear', 'sadness', 'surprise', 'neutral']
Using device: cuda
PyTorch Version: 2.8.0.dev20250507+cu128, CUDA Version: 12.8
Loading preprocessed expression data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/train.expression_cached.csv
Loading preprocessed expression data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/valid.expression_cached.csv


Epoch 1/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s]
Epoch 1/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.59it/s]


✅ Epoch 1: Train Loss = 1.4295, Val Loss = 1.1663
🎉 Saved best model to finetuned_expression_only_best.pt


Epoch 2/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s]
Epoch 2/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.62it/s]


✅ Epoch 2: Train Loss = 0.9854, Val Loss = 1.0611
🎉 Saved best model to finetuned_expression_only_best.pt


Epoch 3/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s]
Epoch 3/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.58it/s]


✅ Epoch 3: Train Loss = 0.8577, Val Loss = 1.0232
🎉 Saved best model to finetuned_expression_only_best.pt


Epoch 4/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s]
Epoch 4/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.53it/s]


✅ Epoch 4: Train Loss = 0.7617, Val Loss = 1.0309
⚠️ No improvement. Patience: 1/10


Epoch 5/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.06it/s]
Epoch 5/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.61it/s]


✅ Epoch 5: Train Loss = 0.6756, Val Loss = 1.0771
⚠️ No improvement. Patience: 2/10


Epoch 6/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.07it/s]
Epoch 6/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.55it/s]


✅ Epoch 6: Train Loss = 0.5772, Val Loss = 1.0926
⚠️ No improvement. Patience: 3/10


Epoch 7/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.05it/s]
Epoch 7/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.54it/s]


✅ Epoch 7: Train Loss = 0.4924, Val Loss = 1.1450
⚠️ No improvement. Patience: 4/10


Epoch 8/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.05it/s]
Epoch 8/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.53it/s]


✅ Epoch 8: Train Loss = 0.4075, Val Loss = 1.1853
⚠️ No improvement. Patience: 5/10


Epoch 9/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.03it/s]
Epoch 9/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.54it/s]


✅ Epoch 9: Train Loss = 0.3336, Val Loss = 1.2399
⚠️ No improvement. Patience: 6/10


Epoch 10/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.04it/s]
Epoch 10/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.52it/s]


✅ Epoch 10: Train Loss = 0.2613, Val Loss = 1.2721
⚠️ No improvement. Patience: 7/10


Epoch 11/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.02it/s]
Epoch 11/100 [Val]: 100%|██████████| 7/7 [00:02<00:00,  3.49it/s]


✅ Epoch 11: Train Loss = 0.2056, Val Loss = 1.3860
⚠️ No improvement. Patience: 8/10


Epoch 12/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.02it/s]
Epoch 12/100 [Val]: 100%|██████████| 7/7 [00:01<00:00,  3.60it/s]


✅ Epoch 12: Train Loss = 0.1584, Val Loss = 1.4373
⚠️ No improvement. Patience: 9/10


Epoch 13/100 [Train]: 100%|██████████| 22/22 [00:10<00:00,  2.02it/s]
Epoch 13/100 [Val]: 100%|██████████| 7/7 [00:02<00:00,  3.42it/s]


✅ Epoch 13: Train Loss = 0.1170, Val Loss = 1.5098
⚠️ No improvement. Patience: 10/10
🛑 Early stopping triggered.

--- Starting Final Testing and Plotting Phase for Expression ---
Best model loaded.
Loading preprocessed expression data from cache: C:/Users/yehte/Downloads/Ye Htet/Projects/TikTok/Annotation/fine-tune/test.expression_cached.csv


[Testing Expression]: 100%|██████████| 4/4 [00:01<00:00,  3.97it/s]



📊 Expression Test Accuracy: 59.00%

--- Plotting up to 100 expression results ---
--- Plotting complete ---
