In [1]:
from hcmus.utils import data_utils

# Fetch the split definitions, then build the image datasets for them.
# `datasets` is indexed by split name later in the notebook (e.g. datasets["train"]).
splits = data_utils.get_data_splits()
datasets = data_utils.get_image_datasets_v2(
    splits,
    random_margin=0,
)

[32m2025-07-01 20:55:32.106[0m | [1mINFO    [0m | [36mhcmus.core.appconfig[0m:[36m<module>[0m:[36m7[0m - [1mLoad DotEnv: True[0m
[32m2025-07-01 20:55:33.257[0m | [1mINFO    [0m | [36mhcmus.lbs._label_studio_connector[0m:[36mget_tasks[0m:[36m152[0m - [1mNew `page_to` applied: 35[0m
Loading tasks: 100%|██████████| 35/35 [00:10<00:00,  3.38it/s]
Downloading images: 100%|██████████| 3443/3443 [00:06<00:00, 559.54it/s] 
[32m2025-07-01 20:55:49.915[0m | [1mINFO    [0m | [36mhcmus.lbs._label_studio_connector[0m:[36mget_tasks[0m:[36m152[0m - [1mNew `page_to` applied: 4[0m
Loading tasks: 100%|██████████| 4/4 [00:02<00:00,  1.35it/s]
Downloading images: 100%|██████████| 375/375 [00:01<00:00, 200.44it/s]
[32m2025-07-01 20:55:54.844[0m | [1mINFO    [0m | [36mhcmus.lbs._label_studio_connector[0m:[36mget_tasks[0m:[36m152[0m - [1mNew `page_to` applied: 3[0m
Loading tasks: 100%|██████████| 3/3 [00:01<00:00,  1.98it/s]
Downloading images: 100%|██████████|

In [2]:
from tqdm import tqdm
from torchvision import transforms as T
from hcmus.models.backbone import CLIPBackbone

In [7]:
# Configuration: CLIP backbone and the preprocessing applied to every dataset image.
device = "mps"  # hard-coded to PyTorch's Apple-Metal (MPS) backend
backbone_name = "ViT-B/32"
backbone = CLIPBackbone(backbone_name=backbone_name, device=device)

# Resize to 224x224, then convert to a float tensor with values in [0, 1].
# NOTE(review): no CLIP mean/std normalization here — confirm CLIPBackbone applies it.
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
    ]
)

In [9]:
import torch  # torch is otherwise only imported in a later cell

# Extract one CLIP embedding per training image, grouped by label:
# {label: [embedding of shape (1, 512), ...]} (shape as shown in the check below).
feature_by_class = {}

# Inference only: disable autograd so stored embeddings do not keep their
# computation graphs alive (harmless if CLIPBackbone already does this internally).
with torch.no_grad():
    for image, label, _ in tqdm(datasets["train"]):
        # fp16 cast presumably matches the backbone's weights on `device` — TODO confirm.
        tensor = transform(image).half().to(device)
        feature_by_class.setdefault(label, []).append(backbone(tensor))

100%|██████████| 2659/2659 [00:57<00:00, 45.91it/s] 


In [12]:
# Sanity check: size of the first stored embedding for label 0.
feature_by_class[0][0].size()

torch.Size([1, 512])

In [None]:
import torch
import numpy as np
from itertools import combinations

def compute_distances(class_embeddings_dict, distance_metric='cosine'):
    """
    Compute inner distance (within classes) and inter distance (between classes).

    Distances for a class (or class pair) are computed as one vectorised matrix
    operation in float32 on CPU. The previous version made one Python call and
    one ``.item()`` call per pair. With a few thousand accelerator-resident
    embeddings, that meant millions of device synchronisations.

    Args:
        class_embeddings_dict: Dict[Hashable, List[torch.Tensor]]. Every tensor is
            flattened, so shapes like (1, 512) are fine. All tensors must have the
            same number of elements.
        distance_metric: 'cosine', 'euclidean', or 'manhattan'

    Returns:
        dict with
          'inner_distances': {class_name: [float, ...]} — pairwise distances within
              each class, in itertools.combinations order (empty list for classes
              with fewer than 2 samples);
          'inter_distances': {"<class1>_vs_<class2>": [float, ...]} — distances for
              every (emb1 in class1, emb2 in class2) pair, emb1-major;
          'metric': the metric used.

    Raises:
        ValueError: if distance_metric is not one of the supported metrics.
    """
    # Fail fast, even when there would be no pairs to compute.
    if distance_metric not in ('cosine', 'euclidean', 'manhattan'):
        raise ValueError("Metric must be 'cosine', 'euclidean', or 'manhattan'")

    def stack(embeddings):
        """Flatten each embedding and stack them into an (n, d) float32 CPU matrix."""
        rows = torch.stack([emb.detach().flatten() for emb in embeddings])
        return rows.to(device='cpu', dtype=torch.float32)

    def pairwise(a, b):
        """Return the (len(a), len(b)) matrix of distances between the rows of a and b."""
        if distance_metric == 'cosine':
            # Equivalent (up to rounding) to 1 - F.cosine_similarity per row pair;
            # eps keeps all-zero vectors from dividing by zero.
            a_unit = torch.nn.functional.normalize(a, dim=1, eps=1e-8)
            b_unit = torch.nn.functional.normalize(b, dim=1, eps=1e-8)
            return 1 - a_unit @ b_unit.T
        p = 2 if distance_metric == 'euclidean' else 1
        # The matmul shortcut for p=2 can lose precision on near-identical vectors.
        return torch.cdist(a, b, p=p, compute_mode='donot_use_mm_for_euclid_dist')

    # Stack every non-empty class once; reused for both inner and inter distances.
    stacked = {
        name: stack(embeddings)
        for name, embeddings in class_embeddings_dict.items()
        if len(embeddings) > 0
    }

    inner_distances = {}
    inter_distances = {}

    # Compute inner distances (within each class)
    print("Computing inner distances...")
    for class_name, embeddings in class_embeddings_dict.items():
        n = len(embeddings)
        if n < 2:
            inner_distances[class_name] = []
            print(f"Class '{class_name}' has only {len(embeddings)} sample(s), skipping inner distance")
            continue

        matrix = pairwise(stacked[class_name], stacked[class_name])
        # Strict upper triangle, row-major: same order as combinations(range(n), 2).
        rows, cols = torch.triu_indices(n, n, offset=1)
        distances = matrix[rows, cols].tolist()

        inner_distances[class_name] = distances
        print(f"Class '{class_name}': {len(distances)} inner distances computed")

    # Compute inter distances (between different classes)
    print("\nComputing inter distances...")
    class_names = list(class_embeddings_dict.keys())

    for class1, class2 in combinations(class_names, 2):
        pair_key = f"{class1}_vs_{class2}"

        if class1 in stacked and class2 in stacked:
            # Row-major flatten == nested loop with emb1 outer, emb2 inner.
            distances = pairwise(stacked[class1], stacked[class2]).flatten().tolist()
        else:
            distances = []  # at least one of the classes has no embeddings

        inter_distances[pair_key] = distances
        print(f"Classes '{class1}' vs '{class2}': {len(distances)} inter distances computed")

    return {
        'inner_distances': inner_distances,
        'inter_distances': inter_distances,
        'metric': distance_metric
    }

def _summarize_distances(distances, include_count=True):
    """Return mean/std/min/max/median (and optionally count) of a non-empty list of distances."""
    # Convert once instead of letting each numpy reduction re-convert the list.
    values = np.asarray(distances)
    summary = {
        'mean': np.mean(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
        'median': np.median(values),
    }
    if include_count:
        summary['count'] = len(distances)
    return summary


def compute_statistics(distances_dict):
    """
    Compute summary statistics for the output of compute_distances.

    Args:
        distances_dict: dict with 'inner_distances' ({class: [float]}) and
            'inter_distances' ({pair_key: [float]}).

    Returns:
        dict with
          'inner_stats': {class: {mean, std, min, max, median, count}},
          'inter_stats': {pair_key: {mean, std, min, max, median, count}},
          'overall_stats': {'inner': {...}, 'inter': {...}} (no 'count'); each
              key is present only if there was at least one distance of that kind.
        Empty distance lists are skipped; numpy's min/max raise on empty input.
    """
    stats = {
        'inner_stats': {},
        'inter_stats': {},
        'overall_stats': {}
    }

    # Inner distance statistics
    all_inner_distances = []
    for class_name, distances in distances_dict['inner_distances'].items():
        if distances:  # classes with < 2 samples have no inner distances
            stats['inner_stats'][class_name] = _summarize_distances(distances)
            all_inner_distances.extend(distances)

    # Inter distance statistics
    all_inter_distances = []
    for pair_name, distances in distances_dict['inter_distances'].items():
        if distances:  # a pair involving a class with no samples is empty
            stats['inter_stats'][pair_name] = _summarize_distances(distances)
            all_inter_distances.extend(distances)

    # Overall statistics
    if all_inner_distances:
        stats['overall_stats']['inner'] = _summarize_distances(all_inner_distances, include_count=False)

    if all_inter_distances:
        stats['overall_stats']['inter'] = _summarize_distances(all_inter_distances, include_count=False)

    return stats

def print_summary(distances_dict, stats):
    """
    Print a human-readable report of distance statistics.

    Args:
        distances_dict: output of compute_distances; only its 'metric' entry is read.
        stats: output of compute_statistics ('inner_stats', 'inter_stats',
            'overall_stats').
    """
    heavy_rule = '=' * 60
    light_rule = '-' * 40
    overall = stats['overall_stats']
    has_inner = 'inner' in overall
    has_inter = 'inter' in overall

    def overall_line(kind, summary):
        # Leading newline separates the aggregate from the per-item rows.
        return "\nOverall {} Distance: Mean={:.4f}, Std={:.4f}".format(
            kind, summary['mean'], summary['std'])

    print('\n' + heavy_rule)
    print("DISTANCE ANALYSIS SUMMARY (Metric: {})".format(distances_dict['metric']))
    print(heavy_rule)

    # Within-class section
    print('\nINNER DISTANCES (within classes):')
    print(light_rule)
    for name, summary in stats['inner_stats'].items():
        print("{:20} | Mean: {:.4f} | Std: {:.4f} | Count: {}".format(
            name, summary['mean'], summary['std'], summary['count']))
    if has_inner:
        print(overall_line('Inner', overall['inner']))

    # Between-class section
    print('\nINTER DISTANCES (between classes):')
    print(light_rule)
    for name, summary in stats['inter_stats'].items():
        print("{:30} | Mean: {:.4f} | Std: {:.4f}".format(
            name, summary['mean'], summary['std']))
    if has_inter:
        print(overall_line('Inter', overall['inter']))

    # Ratio of mean between-class to mean within-class distance
    if not (has_inner and has_inter):
        return
    inner_mean = overall['inner']['mean']
    inter_mean = overall['inter']['mean']
    if inner_mean > 0:
        ratio = inter_mean / inner_mean
    else:
        ratio = float('inf')
    print('\nSEPARABILITY ANALYSIS:')
    print('Separability Ratio (Inter/Inner): {:.4f}'.format(ratio))
    print('Higher ratio indicates better class separability')

# Example usage: run the distance analysis on the CLIP features extracted above.
# (Inside a notebook kernel `__name__` is "__main__", so this block always runs.)
if __name__ == "__main__":
    # Input structure expected by compute_distances: {label: [tensor(1, 512), ...]},
    # e.g. {'class_A': [torch.randn(1, 512) for _ in range(5)], ...}
    example_data = feature_by_class

    # Compute distances
    results = compute_distances(example_data, distance_metric='cosine')

    # Compute statistics
    stats = compute_statistics(results)

    # Print summary
    print_summary(results, stats)

    # Peek at a few raw distances using keys that actually exist. The keys of
    # `feature_by_class` are dataset labels (0, 15, 37, ...), so hard-coded
    # 'class_A' / 'class_A_vs_class_B' lookups would raise KeyError.
    sample_class = next((name for name, dists in results['inner_distances'].items() if dists), None)
    if sample_class is not None:
        print(f"\nExample - Inner distances for {sample_class}: {results['inner_distances'][sample_class][:3]}...")
    sample_pair = next((pair for pair, dists in results['inter_distances'].items() if dists), None)
    if sample_pair is not None:
        print(f"Example - Inter distances for {sample_pair}: {results['inter_distances'][sample_pair][:3]}...")

Computing inner distances...
Class '15': 276 inner distances computed
Class '37': 435 inner distances computed
Class '90': 780 inner distances computed
Class '77': 378 inner distances computed
Class '0': 300 inner distances computed
Class '22': 325 inner distances computed
Class '47': 325 inner distances computed
Class '71': 231 inner distances computed
Class '75': 351 inner distances computed
Class '39': 595 inner distances computed
Class '28': 190 inner distances computed
Class '60': 300 inner distances computed
Class '72': 253 inner distances computed
Class '3': 276 inner distances computed
Class '18': 1378 inner distances computed
Class '25': 276 inner distances computed
Class '81': 595 inner distances computed
Class '58': 561 inner distances computed
Class '79': 406 inner distances computed
Class '38': 253 inner distances computed
Class '46': 465 inner distances computed
Class '34': 300 inner distances computed
Class '21': 561 inner distances computed
Class '40': 300 inner distanc