# In [None]:
# %%
# =========================
# Import Necessary Libraries
# =========================
import torch
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from typing import List
from matplotlib.lines import Line2D
from prettytable import PrettyTable

# %%
# =========================
# Load the CLIP Model and Processor
# =========================
# Hugging Face checkpoint identifier shared by the model and its processor.
MODEL_NAME = "openai/clip-vit-large-patch14"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The processor is built from the same checkpoint as the model; the model
# weights are moved onto the selected device right after loading.
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)

# %%
# =========================
# Define Birds and Their Visual Features
# =========================

# Well-known birds mapped to short phrases describing how they look.
# process_birds() inserts each key into "a photo of a {bird}" and each phrase
# into "{general prompt} with {feature}", so keys and phrases must read
# naturally inside those templates.
famous_birds = {
    "penguin": [
        "black and white feathers",
        "flipper-like wings",
        "in a snowy environment"
    ],
    "eagle": [
        "brown feathers",
        "large wingspan",
        "sharp hooked beak",
    ],
    "flamingo": [
        "pink plumage",
        "long stilt-like legs",
        "S-shaped neck",
    ],
    "owl": [
        "tufts on head resembling ears",
        "camouflaged feathers",
        "silent flight wings"
    ],
    "hummingbird": [
        "long slender bill",
        "rapidly beating wings",
        "hovering in mid-air"
    ]
}

# Fine-grained case: warbler species mapped to their visual feature phrases,
# in the same {name: [feature, ...]} layout as famous_birds.
warblers = {
    "yellow warbler": [
        "bright yellow plumage",
        "thin pointed beak",
        "olive-green streaks on chest"
    ],
    "black-throated blue warbler": [
        "blue upperparts",
        "black throat and face",
        "white belly"
    ],
    "pine warbler": [
        "olive-green upperparts",
        "white wing bars",
        "yellow throat and breast"
    ]
}

# Baseline prompts passed as the first argument of process_birds():
# the generic one is used with famous_birds, the warbler one with warblers.
general_bird_prompt = "a photo of a bird"
general_warbler_prompt = "a photo of a warbler"
# %%
# =========================
# Helper Functions
# =========================

def get_text_embedding(texts: List[str]) -> torch.Tensor:
    """
    Encode text prompts with CLIP and return unit-length embeddings.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        texts (List[str]): Prompts to encode.

    Returns:
        torch.Tensor: Tensor of shape (len(texts), embedding_dim), with each
        row scaled to unit L2 norm.
    """
    # Tokenize as one padded/truncated batch and move it next to the model.
    encoded = processor(
        text=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    encoded = encoded.to(device)

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        features = model.get_text_features(**encoded)

    # Scale every row by its own L2 norm.
    row_norms = features.norm(p=2, dim=-1, keepdim=True)
    return features / row_norms

def compute_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Compute the cosine similarity between two embeddings.

    The comparison runs along the last dimension. The inputs may differ in
    leading singleton dimensions (for example ``(1, embedding_dim)`` against
    ``(embedding_dim,)``) as long as they broadcast to a result that holds
    exactly one value.

    Args:
        a (torch.Tensor): First embedding; last dimension is embedding_dim.
        b (torch.Tensor): Second embedding; last dimension is embedding_dim.

    Returns:
        float: Cosine similarity value. An all-zero input yields 0.0 rather
        than NaN.

    Raises:
        RuntimeError: If the broadcast result has more than one element
            (raised by ``Tensor.item()``).
    """
    # cosine_similarity already divides by the eps-clamped norms, so the
    # inputs are passed through as-is. Dividing by the raw norm beforehand
    # would be redundant and would turn an all-zero vector into NaNs.
    return torch.nn.functional.cosine_similarity(a, b, dim=-1).item()

# %%
# =========================
# Helper Functions (Updated)
# =========================

def display_similarity_table(results: dict, title: str, num_features: int):
    """
    Print a PrettyTable of similarity scores, one row per bird.

    The largest float in each row is wrapped in ANSI bold codes and tagged
    with a trailing marker; entries that are not floats are shown as "-".

    Args:
        results (dict): Maps a bird name to its list of row values, ordered
            as: general similarity, one value per feature, combined
            description text, combined direction vectors.
        title (str): Title shown above the table.
        num_features (int): Number of "After Feature i" columns to create.
    """
    # Header: name column, baseline column, one column per feature, then the
    # two "combined" columns.
    header = ["Bird", "Similarity with General Bird"]
    header += [f"After Feature {idx}" for idx in range(1, num_features + 1)]
    header += ["Combined Description Text", "Combined Direction Vectors"]

    table = PrettyTable()
    table.title = title
    table.field_names = header

    for bird, values in results.items():
        # Only float entries take part in the per-row maximum.
        numeric = [v for v in values if isinstance(v, float)]
        best = max(numeric) if numeric else None

        cells = []
        for v in values:
            if not isinstance(v, float):
                cells.append("-")
            elif v == best:
                # ANSI bold on, value, ANSI reset, then the marker suffix.
                cells.append(f"\033[1m{v:.4f}\033[0m^*")
            else:
                cells.append(f"{v:.4f}")

        # First cell is the bird name with its first letter upper-cased.
        table.add_row([bird.capitalize(), *cells])

    print(table)

# %%
# =========================
# Process Birds (Famous Birds and Warblers)
# =========================

def _build_combined_text(features):
    """Join a bird's feature phrases into one 'A bird with ...' sentence."""
    if len(features) == 1:
        # A single feature needs no list punctuation ("A bird with , and X.").
        return f"A bird with {features[0]}."
    return f"A bird with {', '.join(features[:-1])}, and {features[-1]}."


def process_birds(general_bird_prompt: str, birds_dict: dict, is_fine_grained: bool = False) -> dict:
    """
    Compare each bird's CLIP text embedding with feature-based descriptions.

    For every bird, the embedding of "a photo of a {bird}" is compared with:
    the general prompt, the general prompt extended by each single feature,
    one sentence combining all features, and a "decomposed" embedding built
    by adding the normalised direction vectors (feature prompt minus general
    prompt) to the general embedding. The similarities are printed as a table.

    Args:
        general_bird_prompt (str): Baseline prompt, e.g. "a photo of a bird".
        birds_dict (dict): Maps a bird name to a non-empty list of feature
            phrases. Birds may have different numbers of features.
        is_fine_grained (bool): Selects the fine-grained table title instead
            of the famous-birds one.

    Returns:
        dict: Embeddings for later visualisation, with keys
        "general_bird_embedding", "bird_embeddings_dict",
        "directional_embeddings", "combined_features_embeddings",
        "decomposed_embeddings" and "directional_vectors". The values under
        "decomposed_embeddings" are un-normalised NumPy arrays on the CPU;
        the other entries hold torch tensors.

    Raises:
        ValueError: If birds_dict is empty or a bird has no features.
    """
    # Fail early with a clear message instead of an obscure downstream error.
    if not birds_dict:
        raise ValueError("birds_dict must contain at least one bird")
    for bird, features in birds_dict.items():
        if not features:
            raise ValueError(f"bird {bird!r} has no features")

    # Compute general bird embedding
    general_bird_embedding = get_text_embedding([general_bird_prompt])

    # Compute bird embeddings in one batch and map them back to their names
    bird_prompts = [f"a photo of a {bird}" for bird in birds_dict]
    bird_embeddings = get_text_embedding(bird_prompts)
    bird_embeddings_dict = dict(zip(birds_dict, bird_embeddings))

    # The table gets one feature column per feature of the richest bird
    max_num_features = max(len(features) for features in birds_dict.values())

    directional_embeddings = {}
    combined_features_embeddings = {}
    directional_vectors = {}
    decomposed_embeddings = {}
    similarity_results = {}

    for bird, features in birds_dict.items():
        # Embeddings for each "general prompt with feature" text
        directional_texts = [f"{general_bird_prompt} with {feature}" for feature in features]
        embeddings = get_text_embedding(directional_texts)
        directional_embeddings[bird] = embeddings

        # Direction vectors: offset from the general embedding, unit length
        vectors = embeddings - general_bird_embedding
        vectors = vectors / vectors.norm(p=2, dim=-1, keepdim=True)
        directional_vectors[bird] = vectors

        # Combined description text embedding
        combined_embedding = get_text_embedding([_build_combined_text(features)])
        combined_features_embeddings[bird] = combined_embedding

        # Decompositional approach: general embedding plus all directions.
        # The raw sum is what gets returned; the normalised copy is only
        # used for the similarity below.
        raw_decomposed = general_bird_embedding + vectors.sum(dim=0, keepdim=True)
        decomposed_embeddings[bird] = raw_decomposed.squeeze(0).cpu().numpy()
        decomposed_embedding = raw_decomposed / raw_decomposed.norm(p=2, dim=-1, keepdim=True)

        # Compute similarities against the bird's own name embedding
        bird_embedding = bird_embeddings_dict[bird]
        sim_general = compute_cosine_similarity(general_bird_embedding, bird_embedding)
        sims_features = [
            compute_cosine_similarity(feature_embedding.unsqueeze(0), bird_embedding)
            for feature_embedding in embeddings
        ]
        sim_combined = compute_cosine_similarity(combined_embedding, bird_embedding)
        sim_decomposed = compute_cosine_similarity(decomposed_embedding, bird_embedding)

        # Pad missing feature columns with None (rendered as "-") so every
        # row matches the table header even when feature counts differ.
        padding = [None] * (max_num_features - len(sims_features))
        similarity_results[bird] = [sim_general, *sims_features, *padding, sim_combined, sim_decomposed]

    # Display results in a table
    title = "Similarity Results (Fine-Grained Birds)" if is_fine_grained else "Similarity Results (Famous Birds)"
    display_similarity_table(similarity_results, title, max_num_features)

    # Return embeddings for visualization if needed
    return {
        "general_bird_embedding": general_bird_embedding,
        "bird_embeddings_dict": bird_embeddings_dict,
        "directional_embeddings": directional_embeddings,
        "combined_features_embeddings": combined_features_embeddings,
        "decomposed_embeddings": decomposed_embeddings,
        "directional_vectors": directional_vectors
    }

# %%
# =========================
# Process and Display Results for Famous Birds
# =========================

# Run the comparison for the famous birds against the generic bird prompt.
# The call prints the similarity table and returns the embeddings it computed.
famous_birds_results = process_birds(general_bird_prompt, famous_birds)

# %%
# =========================
# Process and Display Results for Fine-Grained Warblers
# =========================

# Same comparison for the warbler species, using the warbler prompt as the
# baseline; is_fine_grained=True only switches the printed table title.
warblers_results = process_birds(general_warbler_prompt, warblers, is_fine_grained=True)

# %%
# =========================
# Visualization (Optional)
# =========================

# You can use the returned embeddings from process_birds function for visualization if needed.
