In [7]:
import torch
import clip

def validate_features(
    class_pairs: list,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
    """
    Validates whether correct features increase similarity to the fine-grained class
    and wrong features decrease it using the CLIP text encoder for multiple class pairs.

    Args:
        class_pairs (list): A list of dictionaries, each containing:
            - generic_class (str): The generic class description.
            - fine_grained_class (str): The fine-grained class.
            - correct_features (list): List of features that correctly describe the fine-grained class.
            - wrong_features (list): List of features that incorrectly describe the fine-grained class.
        device (str): Device to run the model on ("cuda" or "cpu").

    Returns:
        None
    """
    # Load the CLIP model
    model, preprocess = clip.load("ViT-B/32", device=device)

    for pair in class_pairs:
        generic_class = pair['generic_class']
        fine_grained_class = pair['fine_grained_class']
        correct_features = pair.get('correct_features', [])
        wrong_features = pair.get('wrong_features', [])

        print(f"\n=== Evaluating: '{generic_class}' vs '{fine_grained_class}' ===\n")

        # Create descriptions
        baseline_description = generic_class
        fine_grained_description = fine_grained_class

        correct_descriptions = [f"{generic_class} {feature}" for feature in correct_features]
        wrong_descriptions = [f"{generic_class} {feature}" for feature in wrong_features]

        # Combine all descriptions
        all_descriptions = (
            [baseline_description] +
            correct_descriptions +
            wrong_descriptions +
            [fine_grained_description]
        )

        # Tokenize and encode the texts
        text_tokens = clip.tokenize(all_descriptions).to(device)
        with torch.no_grad():
            text_embeddings = model.encode_text(text_tokens)
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)  # Normalize embeddings

        # Compute cosine similarities
        similarity_matrix = text_embeddings @ text_embeddings.T
        baseline_idx = 0
        fine_grained_idx = len(all_descriptions) - 1

        baseline_similarity = similarity_matrix[baseline_idx, fine_grained_idx].item()

        # Extract similarity scores
        correct_scores = similarity_matrix[1:1+len(correct_descriptions), fine_grained_idx].cpu().numpy()
        wrong_scores = similarity_matrix[1+len(correct_descriptions):-1, fine_grained_idx].cpu().numpy()
        fine_grained_score = similarity_matrix[fine_grained_idx, fine_grained_idx].item()  # Should be 1.0

        # Display the results
        print(f"Baseline Similarity between '{generic_class}' and '{fine_grained_class}': {baseline_similarity:.4f}\n")

        if correct_features:
            print("Correct Features:")
            for desc, score in zip(correct_descriptions, correct_scores):
                improvement = score - baseline_similarity
                print(f"  '{desc}': {score:.4f} (Δ: {improvement:+.4f})")
        else:
            print("No Correct Features Provided.")

        if wrong_features:
            print("\nWrong Features:")
            for desc, score in zip(wrong_descriptions, wrong_scores):
                change = score - baseline_similarity
                print(f"  '{desc}': {score:.4f} (Δ: {change:+.4f})")
        else:
            print("\nNo Wrong Features Provided.")

        # print(f"\nFine-grained Class Description ('{fine_grained_class}'): {fine_grained_score:.4f}\n")

if __name__ == "__main__":
    # Define multiple class pairs with their features
    class_pairs = [
        {
            "generic_class": "a photo of bird",
            "fine_grained_class": "a photo of eagle",
            "correct_features": [
                "with brown feathers",
                "a bird of prey",
                "with a heavy head",
                "with a sharp beak"
            ],
            "wrong_features": [
                "with blue features",
                "with black and white stripes",
                "with a long neck",
                "with colorful plumes"
            ]
        },
        {
            "generic_class": "a photo of bird",
            "fine_grained_class": "a photo of penguin",
            "correct_features": [
                "with black and white feathers",
                "a flightless bird",
                "with a tuxedo-like appearance",
                "living in cold environments"
            ],
            "wrong_features": [
                "with bright red feathers",
                "a bird of prey",
                "with a long beak",
                "flying at high altitudes"
            ]
        },
        {
            "generic_class": "a photo of car",
            "fine_grained_class": "a photo of Jeep",
            "correct_features": [
                "with off-road tires",
                "a rugged vehicle",
                "with a boxy shape",
                "designed for rough terrains"
            ],
            "wrong_features": [
                "with sleek curves",
                "a convertible model",
                "with bright colors",
                "designed for high speed racing"
            ]
        },
        {
            "generic_class": "a photo of horse",
            "fine_grained_class": "a photo of zebra",
            "correct_features": [
                "with black and white stripes",
                "a striped equine",
                "with a unique pattern",
                "distinctive black and white markings"
            ],
            "wrong_features": [
                "with brown spots",
                "a pure white coat",
                "with long flowing manes",
                "a solid-colored horse"
            ]
        }
    ]

    validate_features(class_pairs)



=== Evaluating: 'a photo of bird' vs 'a photo of eagle' ===

Baseline Similarity between 'a photo of bird' and 'a photo of eagle': 0.9062

Correct Features:
  'a photo of bird with brown feathers': 0.7891 (Δ: -0.1172)
  'a photo of bird a bird of prey': 0.8745 (Δ: -0.0317)
  'a photo of bird with a heavy head': 0.8550 (Δ: -0.0513)
  'a photo of bird with a sharp beak': 0.8447 (Δ: -0.0615)

Wrong Features:
  'a photo of bird with blue features': 0.7646 (Δ: -0.1416)
  'a photo of bird with black and white stripes': 0.7363 (Δ: -0.1699)
  'a photo of bird with a long neck': 0.8237 (Δ: -0.0825)
  'a photo of bird with colorful plumes': 0.7539 (Δ: -0.1523)

=== Evaluating: 'a photo of bird' vs 'a photo of penguin' ===

Baseline Similarity between 'a photo of bird' and 'a photo of penguin': 0.8496

Correct Features:
  'a photo of bird with black and white feathers': 0.7241 (Δ: -0.1255)
  'a photo of bird a flightless bird': 0.8188 (Δ: -0.0308)
  'a photo of bird with a tuxedo-like appearance

In [5]:
import torch
import clip
from prettytable import PrettyTable

def validate_features(
    class_pairs: list,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
    """
    Validates whether correct features increase similarity to the fine-grained class
    and wrong features decrease it using the CLIP text encoder for multiple class pairs.
    Results are displayed in neatly formatted tables using PrettyTable.

    Args:
        class_pairs (list): A list of dictionaries, each containing:
            - generic_class (str): The generic class description.
            - fine_grained_class (str): The fine-grained class.
            - correct_features (list): List of features that correctly describe the fine-grained class.
            - wrong_features (list): List of features that incorrectly describe the fine-grained class.
        device (str): Device to run the model on ("cuda" or "cpu").

    Returns:
        None
    """
    # Load the CLIP model
    model, preprocess = clip.load("ViT-B/32", device=device)

    for pair in class_pairs:
        generic_class = pair['generic_class']
        fine_grained_class = pair['fine_grained_class']
        correct_features = pair.get('correct_features', [])
        wrong_features = pair.get('wrong_features', [])

        print(f"\n=== Evaluating: '{generic_class}' vs '{fine_grained_class}' ===\n")

        # Create descriptions
        baseline_description = generic_class
        fine_grained_description = fine_grained_class

        correct_descriptions = [f"{generic_class} {feature}" for feature in correct_features]
        wrong_descriptions = [f"{generic_class} {feature}" for feature in wrong_features]

        # Combine all descriptions
        all_descriptions = (
            [baseline_description] +
            correct_descriptions +
            wrong_descriptions +
            [fine_grained_description]
        )

        # Tokenize and encode the texts
        text_tokens = clip.tokenize(all_descriptions).to(device)
        with torch.no_grad():
            text_embeddings = model.encode_text(text_tokens)
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)  # Normalize embeddings

        # Compute cosine similarities
        similarity_matrix = text_embeddings @ text_embeddings.T
        baseline_idx = 0
        fine_grained_idx = len(all_descriptions) - 1

        baseline_similarity = similarity_matrix[baseline_idx, fine_grained_idx].item()

        # Extract similarity scores
        correct_scores = similarity_matrix[1:1+len(correct_descriptions), fine_grained_idx].cpu().numpy()
        wrong_scores = similarity_matrix[1+len(correct_descriptions):-1, fine_grained_idx].cpu().numpy()
        fine_grained_score = similarity_matrix[fine_grained_idx, fine_grained_idx].item()  # Should be 1.0

        # Create tables
        table_correct = PrettyTable()
        table_wrong = PrettyTable()
        table_summary = PrettyTable()

        # Configure Correct Features Table
        if correct_features:
            table_correct.field_names = ["Description", "Similarity Score", "Δ from Baseline"]
            for desc, score in zip(correct_descriptions, correct_scores):
                improvement = score - baseline_similarity
                delta = f"{improvement:+.4f}"
                table_correct.add_row([desc, f"{score:.4f}", delta])
        else:
            table_correct = None

        # Configure Wrong Features Table
        if wrong_features:
            table_wrong.field_names = ["Description", "Similarity Score", "Δ from Baseline"]
            for desc, score in zip(wrong_descriptions, wrong_scores):
                change = score - baseline_similarity
                delta = f"{change:+.4f}"
                table_wrong.add_row([desc, f"{score:.4f}", delta])
        else:
            table_wrong = None

        # Configure Summary Table
        table_summary.field_names = ["Metric", "Similarity Score"]
        table_summary.add_row([f"Baseline Similarity\n('{baseline_description}' vs '{fine_grained_class}')", f"{baseline_similarity:.4f}"])
        table_summary.add_row([f"Fine-grained Class Description\n('{fine_grained_class}' vs itself)", f"{fine_grained_score:.4f}"])

        # Print Tables
        print(table_summary)

        if table_correct:
            print("\nCorrect Features:")
            print(table_correct)
        else:
            print("\nNo Correct Features Provided.")

        if table_wrong:
            print("\nWrong Features:")
            print(table_wrong)
        else:
            print("\nNo Wrong Features Provided.")

    print("\n=== Evaluation Complete ===\n")

if __name__ == "__main__":
    # Ensure PrettyTable is installed. If not, instruct the user to install it.
    try:
        from prettytable import PrettyTable
    except ImportError:
        print("PrettyTable is not installed. You can install it using 'pip install prettytable'")
        exit(1)

    # Define multiple class pairs with their features
    class_pairs = [
        {
            "generic_class": "a photo of a bird",
            "fine_grained_class": "eagle",
            "correct_features": [
                "with brown feathers",
                "a bird of prey",
                "with a heavy head",
                "with a sharp beak"
            ],
            "wrong_features": [
                "with blue features",
                "with black and white stripes",
                "with a long neck",
                "with colorful plumes"
            ]
        },
        {
            "generic_class": "a photo of a bird",
            "fine_grained_class": "penguin",
            "correct_features": [
                "with black and white feathers",
                "a flightless bird",
                "with a tuxedo-like appearance",
                "living in cold environments"
            ],
            "wrong_features": [
                "with bright red feathers",
                "a bird of prey",
                "with a long beak",
                "flying at high altitudes"
            ]
        },
        {
            "generic_class": "a photo of a car",
            "fine_grained_class": "Jeep",
            "correct_features": [
                "with off-road tires",
                "a rugged vehicle",
                "with a boxy shape",
                "designed for rough terrains"
            ],
            "wrong_features": [
                "with sleek curves",
                "a convertible model",
                "with bright colors",
                "designed for high speed racing"
            ]
        },
        {
            "generic_class": "a photo of a horse",
            "fine_grained_class": "zebra",
            "correct_features": [
                "with black and white stripes",
                "a striped equine",
                "with a unique pattern",
                "distinctive black and white markings"
            ],
            "wrong_features": [
                "with brown spots",
                "a pure white coat",
                "with long flowing manes",
                "a solid-colored horse"
            ]
        }
    ]

    validate_features(class_pairs)



=== Evaluating: 'a photo of a bird' vs 'eagle' ===

+----------------------------------+------------------+
|              Metric              | Similarity Score |
+----------------------------------+------------------+
|       Baseline Similarity        |      0.8657      |
| ('a photo of a bird' vs 'eagle') |                  |
|  Fine-grained Class Description  |      1.0000      |
|       ('eagle' vs itself)        |                  |
+----------------------------------+------------------+

Correct Features:
+---------------------------------------+------------------+-----------------+
|              Description              | Similarity Score | Δ from Baseline |
+---------------------------------------+------------------+-----------------+
| a photo of a bird with brown feathers |      0.7563      |     -0.1094     |
|    a photo of a bird a bird of prey   |      0.8535      |     -0.0122     |
|  a photo of a bird with a heavy head  |      0.8105      |     -0.0552     |
|  a p