# 🚀 FLOSS: Free Lunch in Open-vocabulary Semantic Segmentation

## Reproducibility Notebook

This notebook allows you to reproduce the template ranking results from the FLOSS paper.

**Paper:** [FLOSS: Free Lunch in Open-vocabulary Semantic Segmentation](https://arxiv.org/abs/2504.10487)  
**Project Page:** [https://yasserben.github.io/FLOSS/](https://yasserben.github.io/FLOSS/)  
**GitHub:** [https://github.com/yasserben/FLOSS](https://github.com/yasserben/FLOSS)

---

### 📋 What This Notebook Does

1. **Load pre-computed features** from HuggingFace (or compute them on the fly)
2. **Compute template rankings** using entropy (our method) or your custom metric
3. **Reproduce the paper's results** or experiment with new ranking approaches
4. **Visualize** the rankings and compare different metrics

### 💡 Key Insight from FLOSS

> For each class, there exist single-template classifiers that significantly outperform the conventional averaged classifier using all 80 templates.

We use **entropy** as an unsupervised metric to identify the best template for each class, without requiring any ground-truth labels!
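
The selection rule, as the ranking code in this notebook implements it, can be written compactly. The symbols below ($p_{t,i}$, $\Omega_{t,c}$, $t^\star_c$) are notation introduced here for illustration and are not taken from the paper.

$$
H(p_{t,i}) = -\sum_{k=1}^{C} p_{t,i}(k)\,\log p_{t,i}(k),
\qquad
t^\star_c = \arg\min_{t}\; \frac{1}{|\Omega_{t,c}|} \sum_{i \in \Omega_{t,c}} H(p_{t,i})
$$

Here $p_{t,i}$ is the softmax distribution over the $C$ classes at pixel $i$ when the classifier is built from template $t$ alone, and $\Omega_{t,c}$ is the set of pixels that this classifier assigns to class $c$. Only predictions enter the formula, which is why no ground-truth labels are needed.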

## 📦 Setup & Installation

In [None]:
# Install Dependencies (Run this first!)
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("🔧 Installing dependencies for Google Colab...")
    !pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
    !pip install -q open-clip-torch==2.24.0
    !pip install -q einops ftfy regex tqdm omegaconf
    !pip install -q huggingface_hub datasets
    !pip install -q matplotlib seaborn pandas
    
    import os
    if not os.path.exists('/content/FLOSS'):
        !git clone https://github.com/yasserben/FLOSS.git /content/FLOSS
    
    sys.path.insert(0, '/content/FLOSS')
    %cd /content/FLOSS
    print("✅ Colab setup complete!")
else:
    print("📍 Running locally.")

In [None]:
# Import Libraries
import os
import json
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional
from collections import defaultdict
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {DEVICE}")

## 🎨 Configuration

In [None]:
# Configuration
DATASET = "cityscapes"  # Options: cityscapes, pascalvoc20, pascalco59, ade20k, cocostuff
MODEL_TYPE = "maskclip"  # Options: maskclip, naclip, clipdinoiser
USE_PRECOMPUTED = True  # If True, features are meant to come from HuggingFace (the demo below uses synthetic features)
HF_REPO = "yasserben/floss-features"

print(f"📊 Configuration:")
print(f"   Dataset: {DATASET}")
print(f"   Model: {MODEL_TYPE}")

In [None]:
# ImageNet Templates (80 prompts)
IMAGENET_TEMPLATES = [
    'a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.',
    'a photo of the hard to see {}.', 'a low resolution photo of the {}.',
    'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.',
    'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.',
    'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.',
    'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.',
    'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.',
    'a close-up photo of a {}.', 'a black and white photo of the {}.',
    'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.',
    'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.',
    'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.',
    'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.',
    'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.',
    'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.',
    'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.',
    'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.',
    'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.',
    'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.',
    'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.',
    'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.',
    'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.',
    'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.',
    'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.',
    'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.',
    'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.',
]

print(f"📝 Loaded {len(IMAGENET_TEMPLATES)} templates")
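
Each template is a plain format string, so pairing every template with every class name gives a T x C grid of prompts. The sketch below shows this with three of the templates listed above and two Cityscapes class names; the variable names are chosen for illustration only.

```python
# Three of the 80 templates listed above, reused here for illustration.
templates = ['a photo of a {}.', 'a sketch of a {}.', 'itap of a {}.']
classes = ["road", "traffic light"]

# One list of prompts per template, ordered like the class list.
prompts = [[t.format(c) for c in classes] for t in templates]

for row in prompts:
    print(row)
```

The outer index is the template and the inner index is the class, which is the same layout as the `(T, C, D)` text-feature tensor computed later in the notebook.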

In [None]:
# Dataset Class Names
CLASS_NAMES = {
    "cityscapes": [
        "road", "sidewalk", "building", "wall", "fence", "pole",
        "traffic light", "traffic sign", "vegetation", "terrain", "sky",
        "person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle"
    ],
    "pascalvoc20": [
        "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
        "chair", "cow", "dining table", "dog", "horse", "motorbike", "person",
        "potted plant", "sheep", "sofa", "train", "tvmonitor"
    ],
}

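# Only the datasets listed in CLASS_NAMES are available; anything else falls back to cityscapes.
CLASS_DATASET = DATASET if DATASET in CLASS_NAMES else "cityscapes"
if CLASS_DATASET != DATASET:
    print(f"⚠️ No class names defined for '{DATASET}'; falling back to '{CLASS_DATASET}'")
current_classes = CLASS_NAMES[CLASS_DATASET]
NUM_CLASSES = len(current_classes)
print(f"🏷️ {CLASS_DATASET}: {NUM_CLASSES} classes")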

## 🧠 Feature Extraction

In [None]:
# CLIP Text Encoder
class CLIPTextEncoder:
    def __init__(self, model_name="ViT-B-16", pretrained="laion2b_s34b_b88k", device=DEVICE):
        from open_clip import create_model_from_pretrained, get_tokenizer
        self.device = device
        print(f"📥 Loading CLIP text encoder...")
        self.model, _ = create_model_from_pretrained(model_name, pretrained=pretrained)
        self.model.eval().to(device)
        self.tokenizer = get_tokenizer(model_name)
        print("✅ Text encoder loaded!")
        
    @torch.no_grad()
    def encode_text(self, texts):
        tokens = self.tokenizer(texts).to(self.device)
        features = self.model.encode_text(tokens)
        features = features / features.norm(dim=-1, keepdim=True)
        return features
    
    @torch.no_grad()
    def compute_all_text_features(self, class_names, templates=IMAGENET_TEMPLATES):
        """Return per-template (T, C, D) and template-averaged (C, D) text embeddings."""
        per_template_features = []
        for template in tqdm(templates, desc="Computing text features"):
            prompts = [template.format(cls) for cls in class_names]
            features = self.encode_text(prompts)  # (C, D), L2-normalized
            per_template_features.append(features.cpu())
        per_template_features = torch.stack(per_template_features)  # (T, C, D)
        # Average over templates, then renormalize: the mean of unit vectors is not unit length
        averaged_features = per_template_features.mean(dim=0)
        averaged_features = averaged_features / averaged_features.norm(dim=-1, keepdim=True)
        return {"per_template": per_template_features, "averaged": averaged_features}

print("📦 CLIPTextEncoder defined")
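
The averaged classifier is the mean of the per-template embeddings, normalized again afterwards. A minimal sketch with hypothetical 2-D embeddings (real CLIP embeddings have hundreds of dimensions) shows why the second normalization matters:

```python
import math

def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

# Hypothetical 2-D embeddings of one class under two templates (unit length).
emb_t0 = l2_normalize([1.0, 0.0])
emb_t1 = l2_normalize([0.0, 1.0])

# Element-wise mean over templates: its norm drops below 1 ...
mean = [(a + b) / 2 for a, b in zip(emb_t0, emb_t1)]
mean_norm = math.sqrt(sum(x * x for x in mean))

# ... so it is renormalized before being used as the averaged classifier.
averaged = l2_normalize(mean)
print(round(mean_norm, 4), [round(x, 4) for x in averaged])
```

Without the renormalization, dot products with the averaged embedding would no longer be cosine similarities, and they would not be comparable to those of the single-template classifiers.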

## 📊 Metric Functions

In [None]:
# Entropy and Other Metrics
def compute_entropy(probs, dim=1, eps=1e-10):
    """Lower entropy = higher confidence = better template."""
    log_probs = torch.log(probs + eps)
    return -(probs * log_probs).sum(dim=dim)

def compute_max_prob(probs, dim=1):
    """Higher max prob = higher confidence = better template."""
    return probs.max(dim=dim)[0]

def compute_margin(probs, dim=1):
    """Higher margin = more decisive = better template."""
    sorted_probs = probs.sort(dim=dim, descending=True)[0]
    return sorted_probs.select(dim, 0) - sorted_probs.select(dim, 1)

METRICS = {
    "entropy": {"fn": compute_entropy, "lower_is_better": True},
    "max_prob": {"fn": compute_max_prob, "lower_is_better": False},
    "margin": {"fn": compute_margin, "lower_is_better": False},
}

print("📊 Metrics defined: entropy, max_prob, margin")
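
The three metrics do not always agree on which prediction is "more confident". The sketch below re-implements them in plain Python on two hypothetical per-pixel distributions; it mirrors the formulas above but is not the tensor code the notebook runs.

```python
import math

def entropy(p, eps=1e-10):
    return -sum(x * math.log(x + eps) for x in p)

def max_prob(p):
    return max(p)

def margin(p):
    top = sorted(p, reverse=True)
    return top[0] - top[1]

# Two hypothetical per-pixel class distributions over four classes.
a = [0.5, 0.5, 0.0, 0.0]    # two-way tie
b = [0.5, 0.25, 0.25, 0.0]  # one clear winner, mass spread over three classes

for name, p in [("a", a), ("b", b)]:
    print(name, round(entropy(p), 4), max_prob(p), margin(p))
```

Entropy prefers `a` (lower is better), margin prefers `b` (higher is better), and max probability cannot separate them. This is why each entry in `METRICS` carries its own `lower_is_better` flag.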

In [None]:
# 🎯 YOUR CUSTOM METRIC - Edit this to beat entropy!
def compute_custom_metric(probs, dim=1, eps=1e-10):
    """
    🎯 YOUR CUSTOM METRIC!
    
    Implement your own metric here to beat entropy!
    """
    # Example: Weighted combination of entropy and margin
    entropy = compute_entropy(probs, dim=dim, eps=eps)
    margin = compute_margin(probs, dim=dim)
    return entropy - 0.3 * margin

METRICS["custom"] = {"fn": compute_custom_metric, "lower_is_better": True}
print("✅ Custom metric registered!")
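
The weight on the margin term decides which signal dominates the combined score. A small sketch on two hypothetical distributions, using the same `entropy - weight * margin` form as `compute_custom_metric` (the value 2.0 is an arbitrary choice for contrast, not a recommendation):

```python
import math

def entropy(p, eps=1e-10):
    return -sum(x * math.log(x + eps) for x in p)

def margin(p):
    top = sorted(p, reverse=True)
    return top[0] - top[1]

def combined(p, weight):
    # Same form as compute_custom_metric: lower is better.
    return entropy(p) - weight * margin(p)

a = [0.5, 0.5, 0.0, 0.0]    # two-way tie, lower entropy
b = [0.5, 0.25, 0.25, 0.0]  # clear winner, higher entropy

for w in (0.3, 2.0):
    better = "a" if combined(a, w) < combined(b, w) else "b"
    print(f"weight={w}: a={combined(a, w):.4f} b={combined(b, w):.4f} -> prefers {better}")
```

With the default weight of 0.3 the score still follows entropy; only a much larger weight lets the margin term flip the preference.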

## 🔄 Compute Text Features

In [None]:
# Load or compute text features
print("🧮 Computing text features...")
text_encoder = CLIPTextEncoder()
text_features = text_encoder.compute_all_text_features(current_classes)

print(f"\n📊 Text features ready:")
print(f"   Per-template: {text_features['per_template'].shape}")
print(f"   Averaged: {text_features['averaged'].shape}")

## 📈 Template Ranking

In [None]:
# Template Ranker
class TemplateRanker:
    def __init__(self, text_features, class_names, templates=IMAGENET_TEMPLATES, temperature=0.01):
        self.text_features = text_features
        self.class_names = class_names
        self.templates = templates
        self.temperature = temperature
        self.num_templates = len(templates)
        self.num_classes = len(class_names)
        self.reset()
        
    def reset(self):
        self.metric_accumulator = defaultdict(lambda: defaultdict(lambda: {"sum": 0.0, "count": 0}))
        self.total_pixels = 0
        
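    def compute_segmentation_probs(self, image_features):
        # image_features: (B, D, H, W) dense visual features
        B, _, H, W = image_features.shape
        # Match device and dtype so GPU feature maps work with the CPU-stored text features
        per_template = self.text_features["per_template"].to(image_features)
        T, C, D = per_template.shape
        # A 1x1 convolution takes the dot product of every pixel with all T*C prompt embeddings
        reshaped = per_template.reshape(-1, D)
        output = F.conv2d(image_features, reshaped[:, :, None, None])
        output = output.reshape(B, T, C, H, W)
        # Softmax over the class axis, separately for each template
        probs = F.softmax(output / self.temperature, dim=2)
        return probs.permute(1, 0, 2, 3, 4)  # (T, B, C, H, W)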
    
    def update_metrics(self, image_features, metric_name="entropy"):
        """Accumulate the metric per (template, predicted class); no ground truth is used."""
        B, _, H, W = image_features.shape
        probs = self.compute_segmentation_probs(image_features)  # (T, B, C, H, W)
        metric_fn = METRICS[metric_name]["fn"]
        
        for t_idx in range(self.num_templates):
            template_probs = probs[t_idx]  # (B, C, H, W)
            metric_values = metric_fn(template_probs, dim=1)  # (B, H, W)
            predictions = template_probs.argmax(dim=1)  # pseudo-labels from this template
            
            for c_idx in range(self.num_classes):
                # Only pixels that this template assigns to class c_idx contribute
                mask = predictions == c_idx
                if mask.any():
                    class_metric = metric_values[mask]
                    self.metric_accumulator[t_idx][c_idx]["sum"] += class_metric.sum().item()
                    self.metric_accumulator[t_idx][c_idx]["count"] += mask.sum().item()
        
        self.total_pixels += B * H * W
        
    def compute_rankings(self, metric_name="entropy"):
        lower_is_better = METRICS[metric_name]["lower_is_better"]
        rankings = {"classes": {}}
        
        for c_idx, class_name in enumerate(self.class_names):
            template_scores = []
            for t_idx in range(self.num_templates):
                data = self.metric_accumulator[t_idx][c_idx]
                if data["count"] > 0:
                    avg_metric = data["sum"] / data["count"]
                    pixel_pct = (data["count"] / self.total_pixels) * 100
                else:
                    avg_metric = float('inf') if lower_is_better else float('-inf')
                    pixel_pct = 0.0
                template_scores.append({"template_id": t_idx, metric_name: avg_metric, "pixel_percentage": pixel_pct})
            
            sorted_scores = sorted(template_scores, key=lambda x: x[metric_name], reverse=not lower_is_better)
            for rank, score in enumerate(sorted_scores, 1):
                score["rank"] = rank
            rankings["classes"][class_name] = {f"{metric_name}_ranking": sorted_scores}
        
        return rankings

print("📦 TemplateRanker defined")
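
`TemplateRanker` divides the similarities by a temperature (0.01 by default) before the softmax. Cosine similarities lie in [-1, 1], so without this scaling the softmax output stays close to uniform. A plain-Python sketch with hypothetical similarity values:

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical cosine similarities between one pixel and three class prompts.
sims = [0.30, 0.28, 0.25]

soft = softmax(sims, temperature=1.0)
sharp = softmax(sims, temperature=0.01)
print([round(p, 3) for p in soft])
print([round(p, 3) for p in sharp])
```

At temperature 1.0 the three probabilities are nearly equal; at 0.01 the small gaps between similarities are multiplied by 100 and the top class takes most of the mass, which gives the confidence metrics something to measure.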

## 🎯 Demo with Synthetic Features

In [None]:
# Generate synthetic features for demo
def generate_synthetic_features(num_images=50, height=28, width=28, dim=512):
    """Random unit-norm feature maps standing in for real image features (demo only)."""
    features = []
    for _ in tqdm(range(num_images), desc="Generating features"):
        feat = torch.randn(1, dim, height, width)
        feat = feat / feat.norm(dim=1, keepdim=True)
        features.append(feat)
    return features

synthetic_features = generate_synthetic_features(num_images=50)
print(f"✅ Generated {len(synthetic_features)} synthetic feature maps")

In [None]:
# Compute Rankings
METRIC_NAME = "entropy"  # Try: entropy, max_prob, margin, custom

print(f"📊 Computing rankings using: {METRIC_NAME}")

ranker = TemplateRanker(text_features=text_features, class_names=current_classes)

for feat in tqdm(synthetic_features, desc="Processing"):
    ranker.update_metrics(feat, metric_name=METRIC_NAME)

rankings = ranker.compute_rankings(metric_name=METRIC_NAME)
print(f"\n✅ Rankings computed!")
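
The loop above boils down to two steps: keep a running sum and count of the metric per (template, predicted class), then sort templates by their average. A toy sketch with hypothetical per-pixel records, assuming entropy (lower is better):

```python
from collections import defaultdict

# Hypothetical (template_id, predicted_class, entropy) records, one per pixel.
pixels = [
    (0, "road", 0.20), (0, "road", 0.40), (0, "sky", 0.90),
    (1, "road", 0.10), (1, "road", 0.10), (1, "sky", 1.10),
]

# Running sum and count per (template, class), as in update_metrics.
acc = defaultdict(lambda: {"sum": 0.0, "count": 0})
for t, cls, h in pixels:
    acc[(t, cls)]["sum"] += h
    acc[(t, cls)]["count"] += 1

def rank_templates(cls, num_templates=2):
    scores = []
    for t in range(num_templates):
        d = acc[(t, cls)]
        # A template that never predicts the class is pushed to the end.
        avg = d["sum"] / d["count"] if d["count"] else float("inf")
        scores.append((avg, t))
    return [t for _, t in sorted(scores)]  # ascending: lower entropy first

print(rank_templates("road"), rank_templates("sky"))
```

Template 1 ranks first for "road" and template 0 ranks first for "sky", so the best template is chosen per class rather than once for the whole dataset.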

In [None]:
# Display Rankings
print(f"\n🏆 Top-3 Templates per Class (by {METRIC_NAME})")
print("=" * 70)

ranking_field = f"{METRIC_NAME}_ranking"
for class_name, class_data in list(rankings["classes"].items())[:5]:
    top_templates = class_data[ranking_field][:3]
    print(f"\n📌 {class_name}:")
    for t in top_templates:
        template_str = IMAGENET_TEMPLATES[t['template_id']].format(class_name)
        print(f"   [{t['rank']:2d}] T{t['template_id']:2d} | {METRIC_NAME}={t[METRIC_NAME]:.4f} | \"{template_str}\"")

## 📊 Visualization

In [None]:
# Visualize Rankings
def visualize_rankings(rankings, metric_name, num_classes=8):
    ranking_field = f"{metric_name}_ranking"
    class_names = list(rankings["classes"].keys())[:num_classes]
    
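    # Take the template count from the rankings instead of hard-coding 80
    num_templates = len(rankings["classes"][class_names[0]][ranking_field])
    rank_matrix = np.zeros((len(class_names), num_templates))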
    for i, cn in enumerate(class_names):
        for t_data in rankings["classes"][cn][ranking_field]:
            rank_matrix[i, t_data["template_id"]] = t_data["rank"]
    
    plt.figure(figsize=(16, 6))
    im = plt.imshow(rank_matrix, cmap="RdYlGn_r", aspect="auto")
    plt.yticks(range(len(class_names)), class_names)
    plt.xlabel("Template ID")
    plt.ylabel("Class")
    plt.title(f"Template Rankings by {metric_name.capitalize()}")
    plt.colorbar(im, label="Rank")
    plt.tight_layout()
    plt.show()

visualize_rankings(rankings, METRIC_NAME)
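
The heatmap shows one metric at a time. To compare two metrics numerically, one option (not part of the original notebook) is the Spearman rank correlation between their per-class rankings. The sketch below uses the closed form for tie-free rankings and hypothetical ranks for five templates:

```python
def spearman(rank_a, rank_b):
    """Spearman correlation between two tie-free rankings of the same items.

    rank_a and rank_b map template_id -> rank (1 = best).
    """
    n = len(rank_a)
    d2 = sum((rank_a[t] - rank_b[t]) ** 2 for t in rank_a)
    return 1 - 6 * d2 / (n * (n ** 2 - 1))

# Hypothetical ranks of five templates for one class under two metrics.
by_entropy = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}
by_margin = {0: 2, 1: 1, 2: 3, 3: 5, 4: 4}

print(spearman(by_entropy, by_margin))
print(spearman(by_entropy, by_entropy))
```

With the notebook's output, such a dictionary can be built per class as `{t["template_id"]: t["rank"] for t in rankings["classes"][cn][f"{metric}_ranking"]}`. A value near 1 means the two metrics order the templates almost identically.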

## 🏁 Load Paper's Rankings

In [None]:
# Load Paper Rankings
def load_paper_rankings(dataset, model="naclip"):
    if IN_COLAB:
        path = f"/content/FLOSS/rankings/{model}/template_rankings_{model}_{dataset}.json"
    else:
        path = f"rankings/{model}/template_rankings_{model}_{dataset}.json"
    
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"⚠️ Rankings not found at {path}")
        return None

paper_rankings = load_paper_rankings(DATASET, MODEL_TYPE)

if paper_rankings:
    print("📄 Paper's Entropy Rankings (Top 3):")
    for cn in list(paper_rankings["classes"].keys())[:3]:
        ranking = paper_rankings["classes"][cn].get("entropy_ranking", [])[:3]
        print(f"\n   {cn}: {[t['template_id'] for t in ranking]}")
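
A simple way to compare your own rankings with the loaded ones is the overlap of their top-k templates per class. The sketch below assumes only the fields this notebook itself reads (`classes`, `entropy_ranking`, `template_id`); the dictionaries and template IDs are hypothetical, and the helper names are introduced here for illustration.

```python
# Hypothetical rankings in the layout that compute_rankings() produces and
# load_paper_rankings() reads: classes -> "<metric>_ranking" -> sorted entries.
mine = {"classes": {"road": {"entropy_ranking": [
    {"template_id": 7, "rank": 1}, {"template_id": 3, "rank": 2}, {"template_id": 9, "rank": 3},
]}}}
paper = {"classes": {"road": {"entropy_ranking": [
    {"template_id": 3, "rank": 1}, {"template_id": 7, "rank": 2}, {"template_id": 1, "rank": 3},
]}}}

def top_k_ids(rankings, class_name, k=3, field="entropy_ranking"):
    entries = rankings["classes"][class_name].get(field, [])[:k]
    return [e["template_id"] for e in entries]

def top_k_overlap(a, b, class_name, k=3):
    # Fraction of the top-k template IDs that both rankings share.
    shared = set(top_k_ids(a, class_name, k)) & set(top_k_ids(b, class_name, k))
    return len(shared) / k

print(top_k_overlap(mine, paper, "road"))
```

Rankings computed from the synthetic features in this demo should not be expected to overlap with the paper's, since random features carry no image content.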

## 🎉 Summary

You've learned how to:
- Compute template rankings using entropy
- Experiment with custom metrics
- Visualize and compare results

**Challenge:** Can you find a better metric than entropy? Edit `compute_custom_metric()`!