# In [None]:
import torch
from PIL import Image
import open_clip
import os

# ----- Model configuration -----
MODEL_NAME = "ViT-B-32"  # open_clip architecture name; must match the checkpoint below
CHECKPOINT_PATH = "your_finetuned_clip_checkpoint.pt"  # <-- point this at your fine-tuned weights
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"  # prefer the GPU whenever torch can see one

# Build the network together with its image transforms. Only the model and the
# last returned transform (used below as the inference-time preprocessing) are
# kept; the middle return value is discarded.
model, _, preprocess = open_clip.create_model_and_transforms(
    MODEL_NAME,
    pretrained=CHECKPOINT_PATH,
)
model = model.to(DEVICE)
model.eval()  # switch to inference mode before any encoding happens
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

# ----- Candidate prompt sets (one list per predicted attribute) -----
IDENTITY_LIST = ["Daniel", "Yuki", "Alex", "Emma", "John"]
# One prompt per candidate name, e.g. "A person named Daniel".
IDENTITY_PROMPTS = list(map("A person named {}".format, IDENTITY_LIST))

AGE_PROMPTS = [
    "A child",
    "A teenager",
    "A 25-year-old adult",
    "A 40-year-old adult",
    "A 60-year-old senior",
]

GENDER_PROMPTS = [
    "A man",
    "A woman",
    "A non-binary person",
]

EXPRESSION_PROMPTS = [
    "A person with a happy expression",
    "A person with a sad expression",
    "A person with an angry expression",
    "A person with a surprised expression",
    "A person with a neutral expression",
]

# ----- Helper Functions -----
def encode_text_prompts(prompts, tokenizer, model):
    """Embed text prompts with the model's text encoder and L2-normalize them.

    Args:
        prompts: Prompt strings to encode.
        tokenizer: Callable turning ``prompts`` into a token tensor; the result
            is moved to the module-level ``DEVICE``.
        model: Model exposing ``encode_text``.

    Returns:
        Text embeddings scaled to unit L2 norm along the last dimension.
    """
    with torch.no_grad():
        token_ids = tokenizer(prompts).to(DEVICE)
        embeddings = model.encode_text(token_ids)
        # Normalize in place so later dot products are cosine similarities.
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
    return embeddings

def encode_image(image_path, preprocess, model):
    """Load one image from disk and return its L2-normalized embedding.

    Args:
        image_path: Path of the image file, opened with ``PIL.Image.open``.
        preprocess: Transform mapping an RGB PIL image to an image tensor.
        model: Model exposing ``encode_image``.

    Returns:
        Image embedding with a leading batch dimension of 1, scaled to unit
        L2 norm along the last dimension.

    Raises:
        Propagates whatever ``Image.open`` / ``convert`` raise for a missing or
        unreadable file; no errors are caught here.
    """
    # Close the file handle deterministically: convert() loads the pixels and
    # returns an independent image, so the source can be closed right away
    # instead of leaking the handle until garbage collection.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features

def predict_attribute(image_features, text_features, prompt_list, top_k=1):
    """Rank candidate prompts by similarity to a single image embedding.

    Args:
        image_features: Image embedding; a leading dimension of size 1 is
            squeezed away, so this expects features for one image.
        text_features: Text embeddings, one row per entry of ``prompt_list``.
        prompt_list: Prompt strings aligned with the rows of ``text_features``.
        top_k: Number of best matches to return. Values larger than the number
            of prompts are clamped instead of raising.

    Returns:
        List of ``(prompt, score)`` tuples from highest to lowest score. Each
        score is the dot product as a Python float, i.e. the cosine similarity
        when both inputs are unit-normalized.
    """
    similarities = (image_features @ text_features.T).squeeze(0)
    # Tensor.topk raises when k exceeds the number of candidates; clamp instead.
    k = min(top_k, similarities.shape[-1])
    scores, indices = similarities.topk(k)
    # One host transfer via tolist() rather than a tensor __index__ plus .item()
    # per hit; the values are exactly the selected similarities.
    return [
        (prompt_list[idx], score)
        for score, idx in zip(scores.tolist(), indices.tolist())
    ]

# ----- Main Inference Function -----
# Text embeddings depend only on the constant prompt lists, so they are encoded
# once and reused across calls instead of re-running the text encoder for every
# image. NOTE(review): assumes the module-level `model`/`tokenizer` are not
# replaced after import; clear this dict if they are.
_TEXT_FEATURE_CACHE = {}


def _text_features_for(prompts):
    """Return normalized text embeddings for ``prompts``, encoding them on first use."""
    # Key on the prompt contents so an edited prompt list is re-encoded.
    key = tuple(prompts)
    features = _TEXT_FEATURE_CACHE.get(key)
    if features is None:
        features = encode_text_prompts(prompts, tokenizer, model)
        _TEXT_FEATURE_CACHE[key] = features
    return features


def predict_face_attributes(image_path):
    """Predict identity, age, gender and expression for one face image.

    Each attribute is chosen zero-shot: the image embedding is compared with
    the embedding of every prompt in that attribute's prompt list, and the
    single highest-scoring prompt is reported.

    Args:
        image_path: Path to the image file.

    Returns:
        Dict mapping each of ``identity``, ``age``, ``gender`` and
        ``expression`` to the winning prompt string, and
        ``<attribute>_score`` to that prompt's similarity as a Python float.
    """
    print(f"\n🔍 Predicting attributes for: {image_path}")

    image_features = encode_image(image_path, preprocess, model)

    # (result key, candidate prompts), in the order results are reported.
    attribute_prompts = (
        ("identity", IDENTITY_PROMPTS),
        ("age", AGE_PROMPTS),
        ("gender", GENDER_PROMPTS),
        ("expression", EXPRESSION_PROMPTS),
    )

    results = {}
    for attribute, prompts in attribute_prompts:
        best_prompt, best_score = predict_attribute(
            image_features, _text_features_for(prompts), prompts
        )[0]
        results[attribute] = best_prompt
        results[f"{attribute}_score"] = best_score
    return results

# ----- CLI Example -----
if __name__ == "__main__":
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Multi-Attribute Face Predictor using OpenCLIP")
    parser.add_argument("image_path", type=str, help="Path to face image")
    args = parser.parse_args()

    # isfile() rather than exists(): a directory path is rejected here instead
    # of failing later inside PIL. sys.exit() rather than the site-provided
    # exit(), which is meant for the interactive shell and may be absent
    # (e.g. under `python -S`). The error goes to stderr, not stdout.
    if not os.path.isfile(args.image_path):
        print("❌ Image file does not exist.", file=sys.stderr)
        sys.exit(1)

    results = predict_face_attributes(args.image_path)
    print("\n🧠 Prediction Results:")
    for key, value in results.items():
        print(f"{key:>18}: {value}")
