# Step 6: Iteration Loop

Now you have a trained model and know where it fails. This step shows how to:
1. Run a **Golden QA check** to detect annotation drift
2. Select the **next batch** using a hybrid strategy

The hybrid strategy balances:
- **30% Coverage** - Diversity sampling to avoid tunnel vision
- **70% Targeted** - Samples similar to failures

This balance is critical. Chasing only failures produces a model that is strong on edge cases and weak on typical inputs.

In [None]:
!pip install -q scikit-learn

In [None]:
import fiftyone as fo
import fiftyone.brain as fob
from fiftyone import ViewField as F
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from collections import Counter

LABEL_FIELD = "human_labels"

dataset = fo.load_dataset("annotation_tutorial")

# Get schema classes
if "annotation_schema" in dataset.info:
    SCHEMA_CLASSES = set(dataset.info["annotation_schema"]["classes"])
else:
    SCHEMA_CLASSES = {
        "person", "car", "truck", "bus", "motorcycle", "bicycle",
        "dog", "cat", "bird", "horse",
        "chair", "couch", "dining table", "tv",
        "bottle", "cup", "bowl", "other"
    }

## Golden QA Check

Before selecting the next batch, verify annotation quality hasn't drifted. The golden set is a small, carefully reviewed sample we check each iteration.

**What to look for:**
- Label count distribution staying stable
- No unexpected empty samples
- Class distribution roughly matching earlier rounds

In [None]:
# Load golden QA set
golden = dataset.load_saved_view("golden_qa")

# For this tutorial, simulate annotation by copying ground_truth into human_labels
# where it is missing, keeping only schema classes so the golden set matches the schema
for sample in golden:
    if sample.ground_truth and not sample[LABEL_FIELD]:
        filtered_dets = [
            fo.Detection(label=d.label, bounding_box=d.bounding_box)
            for d in sample.ground_truth.detections
            if d.label in SCHEMA_CLASSES
        ]
        sample[LABEL_FIELD] = fo.Detections(detections=filtered_dets)
        sample.save()

print(f"Golden QA set: {len(golden)} samples")

In [None]:
# Golden QA Check: Compute baseline stats
golden_stats = {
    "total_samples": len(golden),
    "samples_with_labels": 0,
    "total_detections": 0,
    "class_counts": Counter()
}

for sample in golden:
    if sample[LABEL_FIELD] and len(sample[LABEL_FIELD].detections) > 0:
        golden_stats["samples_with_labels"] += 1
        golden_stats["total_detections"] += len(sample[LABEL_FIELD].detections)
        for det in sample[LABEL_FIELD].detections:
            golden_stats["class_counts"][det.label] += 1

print("=" * 40)
print("GOLDEN QA CHECK")
print("=" * 40)
print(f"Samples with labels: {golden_stats['samples_with_labels']}/{golden_stats['total_samples']}")
print(f"Total detections: {golden_stats['total_detections']}")
print(f"Avg detections/sample: {golden_stats['total_detections']/max(1,golden_stats['samples_with_labels']):.1f}")
print("\nTop classes:")
for cls, count in golden_stats["class_counts"].most_common(5):
    print(f"  {cls}: {count}")
print("=" * 40)
print("\nIf these numbers change unexpectedly between iterations,")
print("investigate annotation consistency before continuing.")

In [None]:
# Store golden stats for comparison in future iterations.
# Reassigning the list (rather than mutating it in place) makes the update
# explicit before dataset.save() persists dataset.info.
# Note: re-running this cell appends another entry.
history = list(dataset.info.get("golden_qa_history", []))
history.append({
    "iteration": len(history),
    "samples_with_labels": golden_stats["samples_with_labels"],
    "total_detections": golden_stats["total_detections"],
    "top_classes": dict(golden_stats["class_counts"].most_common(5))
})
dataset.info["golden_qa_history"] = history
dataset.save()

print(f"Saved golden QA stats (iteration {len(dataset.info['golden_qa_history'])-1})")
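Comparing the stored history entries by eye works for a couple of rounds; a small helper makes the "change unexpectedly" check explicit. The sketch below assumes each entry has the shape appended to `dataset.info["golden_qa_history"]` above; the function name `golden_drift_warnings` and the 10% tolerance are illustrative choices, not values from this tutorial.

```python
def golden_drift_warnings(prev, curr, rel_tol=0.10):
    """Compare two golden QA history entries and return human-readable warnings.

    Assumes entries shaped like those stored in dataset.info["golden_qa_history"].
    rel_tol is an illustrative threshold, not a recommended value.
    """
    warnings = []

    # Relative change in the scalar counts
    for key in ("samples_with_labels", "total_detections"):
        before, after = prev[key], curr[key]
        change = abs(after - before) / max(1, before)
        if change > rel_tol:
            warnings.append(f"{key}: {before} -> {after} ({change:.0%} change)")

    # Classes that entered or left the stored top-class list
    prev_top, curr_top = set(prev["top_classes"]), set(curr["top_classes"])
    for cls in sorted(prev_top - curr_top):
        warnings.append(f"class '{cls}' dropped out of the top classes")
    for cls in sorted(curr_top - prev_top):
        warnings.append(f"class '{cls}' newly entered the top classes")

    return warnings

# Hypothetical usage once at least two iterations are stored:
# history = dataset.info["golden_qa_history"]
# if len(history) >= 2:
#     for w in golden_drift_warnings(history[-2], history[-1]):
#         print("WARNING:", w)
```

An empty list means no tracked statistic moved beyond the tolerance; it does not prove the labels themselves are correct.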

## Prepare for Next Batch Selection

In [None]:
# Load the unlabeled pool
pool = dataset.load_saved_view("pool")

# Keep only truly unlabeled samples, excluding ones already annotated,
# selected for a batch, or pending annotation
remaining = pool.match(F("annotation_status") == "unlabeled")

# Get failure samples from evaluation
try:
    failures = dataset.load_saved_view("eval_v0_failures")
    print(f"Failure samples: {len(failures)}")
except Exception:
    failures = dataset.limit(0)  # Empty view
    print("No failure view found. Run Step 5 first, or continue with coverage-only selection.")

print(f"Remaining unlabeled pool: {len(remaining)} samples")

## Define Acquisition Budget

**Batch sizing guidance:**
- Size batches to your labeling capacity (e.g., half a day to one day of work)
- For this tutorial, we'll select ~20% of the remaining pool, with a minimum of 10 samples
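Turning "half a day to one day of work" into a sample count is simple arithmetic once you know your average labeling time per image. A minimal sketch, where `batch_size_from_capacity` is a hypothetical helper and `seconds_per_sample` is an assumption you would measure yourself (it varies widely with the number of objects per image):

```python
import math

def batch_size_from_capacity(hours_available, seconds_per_sample, pool_size, min_batch=10):
    """Rough batch size from labeling capacity.

    seconds_per_sample is your own measured average labeling time (assumed input).
    The result is at least min_batch and never more than the unlabeled samples left.
    """
    capacity = math.floor(hours_available * 3600 / seconds_per_sample)
    return min(pool_size, max(min_batch, capacity))

# Example with made-up throughput: a 4-hour half-day at 60 s per image
print(batch_size_from_capacity(4, 60, pool_size=1000))  # 240
```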

In [None]:
# Select batch size based on remaining pool:
# 20% of remaining, at least 10, but never more than what is left
batch_size = min(len(remaining), max(10, int(0.20 * len(remaining))))

# Split: 30% coverage (ZCore), 70% targeted
coverage_budget = int(0.30 * batch_size)
targeted_budget = batch_size - coverage_budget

print(f"Batch v1 budget: {batch_size} samples")
print(f"  Coverage (diversity): {coverage_budget} (30%)")
print(f"  Targeted (failures): {targeted_budget} (70%)")

## Part 1: Coverage Selection (30%)

Use the ZCore scores computed in Step 3 to select diverse samples from the remaining pool, taking the highest-scoring samples first.
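If ZCore scores are missing, the cell below selects no coverage samples. One common substitute for diversity sampling is greedy k-center (farthest-point) selection on embeddings. This is a different, swapped-in technique, not ZCore; `k_center_greedy` is a hypothetical helper written for illustration:

```python
import numpy as np

def k_center_greedy(embeddings, k, seed_index=0):
    """Greedy k-center (farthest-point) selection.

    Repeatedly picks the point farthest (Euclidean) from everything selected
    so far, spreading picks across embedding space. Generic diversity
    heuristic, NOT the ZCore method. Returns indices into `embeddings`.
    """
    X = np.asarray(embeddings, dtype=float)
    n = len(X)
    if n == 0 or k <= 0:
        return []
    k = min(k, n)
    selected = [seed_index]
    # Distance from every point to its nearest selected point
    min_dist = np.linalg.norm(X - X[seed_index], axis=1)
    while len(selected) < k:
        nxt = int(np.argmax(min_dist))
        if min_dist[nxt] == 0:
            break  # every remaining point duplicates a selected one
        selected.append(nxt)
        min_dist = np.minimum(min_dist, np.linalg.norm(X - X[nxt], axis=1))
    return selected

points = [[0, 0], [1, 0], [10, 0], [10, 1]]
print(k_center_greedy(points, 2))  # [0, 3]
```

To apply it to the pool, you would collect embeddings and sample IDs for unlabeled samples that have embeddings, then map the returned indices back to IDs.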

In [None]:
# Get samples with ZCore scores from remaining pool
remaining_with_scores = remaining.match(F("zcore") != None)

if len(remaining_with_scores) == 0:
    print("No ZCore scores found in remaining pool. Run Step 3 first.")
    coverage_ids = set()
else:
    # Select top by ZCore (already computed in Step 3)
    coverage_samples = remaining_with_scores.sort_by("zcore", reverse=True).limit(coverage_budget)
    coverage_ids = set(s.id for s in coverage_samples)
    print(f"Coverage selection (ZCore): {len(coverage_ids)} samples")

## Part 2: Targeted Selection (70%)

Find samples similar to failures using embedding-based neighbor search.
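Cosine similarity is the dot product of L2-normalized vectors, so nearest-neighbor search over a small pool reduces to one matrix multiply and a sort. The NumPy sketch below (`top_k_cosine` is a hypothetical name) shows the computation that `cosine_similarity` plus `argsort` perform in the `find_neighbors` helper:

```python
import numpy as np

def top_k_cosine(queries, pool, k):
    """For each query row, return indices of the k most cosine-similar pool rows,
    most similar first."""
    Q = np.asarray(queries, dtype=float)
    P = np.asarray(pool, dtype=float)
    # L2-normalize rows; the small epsilon guards against zero vectors
    Q = Q / (np.linalg.norm(Q, axis=1, keepdims=True) + 1e-12)
    P = P / (np.linalg.norm(P, axis=1, keepdims=True) + 1e-12)
    sims = Q @ P.T  # (n_queries, n_pool) cosine similarities
    k = min(k, P.shape[0])
    return np.argsort(-sims, axis=1)[:, :k]

failure = [[1, 0]]
pool = [[0, 1], [1, 0.1], [-1, 0], [2, 0]]
print(top_k_cosine(failure, pool, 2).tolist())  # [[3, 1]]
```

Note that `[2, 0]` ranks first: cosine similarity ignores vector magnitude and compares direction only.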

In [None]:
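def find_neighbors(query_embs, pool_embs, pool_ids, n_per_query=3):
    """Find nearest neighbors in embedding space.

    Returns the pool IDs that appear in any query's top-n_per_query list,
    ranked by their highest cosine similarity to any query (closest first),
    so truncating the result to a budget keeps the best matches.
    """
    if len(query_embs) == 0 or len(pool_embs) == 0:
        return []

    sims = cosine_similarity(query_embs, pool_embs)  # shape: (n_queries, n_pool)
    best_sim = {}  # pool ID -> highest similarity to any query

    for sim_row in sims:
        # Indices of the n_per_query most similar pool samples for this query
        top_idx = np.argsort(sim_row)[-n_per_query:]
        for idx in top_idx:
            sid = pool_ids[idx]
            best_sim[sid] = max(best_sim.get(sid, -np.inf), sim_row[idx])

    # Sort for a deterministic, similarity-ordered result
    return sorted(best_sim, key=best_sim.get, reverse=True)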

In [None]:
# Get embeddings for remaining samples, excluding ones already picked for coverage
candidates = [
    s for s in remaining
    if s.embeddings is not None and s.id not in coverage_ids
]
remaining_embs = np.array([s.embeddings for s in candidates])
remaining_ids = [s.id for s in candidates]

if len(failures) > 0 and len(remaining_embs) > 0:
    failure_embs = np.array([s.embeddings for s in failures if s.embeddings is not None])
    print(f"Finding neighbors of {len(failure_embs)} failure samples...")

    # Find neighbors among candidates (coverage samples were excluded above,
    # so they cannot crowd out targeted picks)
    failure_neighbors = find_neighbors(failure_embs, remaining_embs, remaining_ids, n_per_query=5)
    targeted_ids = failure_neighbors[:targeted_budget]
    print(f"Targeted selection: {len(targeted_ids)} samples")
else:
    print("No failures to target or no embeddings. Using coverage-only selection.")
    # Fall back to more coverage samples
    if len(remaining_with_scores) > coverage_budget:
        extra_coverage = remaining_with_scores.sort_by("zcore", reverse=True).skip(coverage_budget).limit(targeted_budget)
        targeted_ids = [s.id for s in extra_coverage if s.id not in coverage_ids]
    else:
        targeted_ids = []
    print(f"Additional coverage selection: {len(targeted_ids)} samples")
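The targeted step can return fewer IDs than `targeted_budget`, for example when there are only a few failures or their neighborhoods overlap. A minimal sketch for topping up the shortfall from a fallback ranking; `fill_to_budget` is a hypothetical helper, not part of FiftyOne:

```python
def fill_to_budget(primary_ids, fallback_ids, budget, exclude=()):
    """Take IDs from primary_ids first, then top up from fallback_ids,
    skipping duplicates and anything in `exclude`, until `budget` is reached."""
    chosen, seen = [], set(exclude)
    for sid in list(primary_ids) + list(fallback_ids):
        if len(chosen) >= budget:
            break
        if sid not in seen:
            chosen.append(sid)
            seen.add(sid)
    return chosen

# Hypothetical usage: top up targeted picks with the next-best ZCore samples
# zcore_ranked = [s.id for s in remaining_with_scores.sort_by("zcore", reverse=True)]
# targeted_ids = fill_to_budget(targeted_ids, zcore_ranked, targeted_budget, exclude=coverage_ids)

print(fill_to_budget(["a", "b"], ["b", "c", "d"], 3, exclude={"c"}))  # ['a', 'b', 'd']
```

If you top up this way, consider tagging the extra samples as coverage rather than targeted so the `source:` tags stay accurate.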

## Combine and Tag Batch v1

In [None]:
# Combine selections
batch_v1_ids = list(coverage_ids) + targeted_ids

if len(batch_v1_ids) == 0:
    print("No samples selected. Check that Steps 3 and 5 completed successfully.")
else:
    batch_v1 = dataset.select(batch_v1_ids)

    # Tag
    batch_v1.tag_samples("batch:v1")
    batch_v1.tag_samples("to_annotate")
    batch_v1.set_values("annotation_status", ["selected"] * len(batch_v1))

    # Track source for analysis
    if len(coverage_ids) > 0:
        dataset.select(list(coverage_ids)).tag_samples("source:coverage")
    if len(targeted_ids) > 0:
        dataset.select(targeted_ids).tag_samples("source:targeted")

    # Save view
    dataset.save_view("batch_v1", dataset.match_tags("batch:v1"))

    print(f"\nBatch v1: {len(batch_v1)} samples")
    print(f"  Coverage: {len(coverage_ids)}")
    print(f"  Targeted: {len(targeted_ids)}")

## The Complete Loop

You now have the full iteration recipe:

```
1. Run Golden QA check (detect drift)
2. Annotate the current batch (Step 4 workflow)
3. Train on all annotated data (Step 5)
4. Evaluate on val set, tag failures
5. Select next batch: 30% coverage + 70% targeted
6. Repeat until stopping criteria
```

### Stopping Criteria

Stop when:
- Gains per labeled sample flatten (diminishing returns)
- Remaining failures are mostly label ambiguity
- Val metrics hit your target threshold
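The first and third criteria can be checked mechanically from a per-iteration log of labeled-sample counts and validation metric; judging whether failures stem from label ambiguity still requires looking at them. A minimal sketch, where `should_stop`, the thresholds, and the history format are all assumptions to adapt:

```python
def gain_per_sample(history):
    """history: list of (num_labeled, val_metric) pairs, one per iteration, in order.
    Returns the metric gain per newly labeled sample for each consecutive pair."""
    gains = []
    for (n0, m0), (n1, m1) in zip(history, history[1:]):
        added = n1 - n0
        gains.append((m1 - m0) / added if added > 0 else 0.0)
    return gains

def should_stop(history, target=None, min_gain_per_sample=1e-4, patience=2):
    """Stop if the metric reached `target`, or if each of the last `patience`
    iterations gained less than `min_gain_per_sample`.
    Thresholds are placeholders to tune for your metric and budget."""
    if not history:
        return False
    if target is not None and history[-1][1] >= target:
        return True
    gains = gain_per_sample(history)
    return len(gains) >= patience and all(g < min_gain_per_sample for g in gains[-patience:])

# Illustrative numbers only, not results from this tutorial
history = [(100, 0.30), (200, 0.40), (300, 0.402), (400, 0.403)]
print(should_stop(history))  # True: the last two rounds added almost nothing
```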

### The 30% Coverage Rule

**Don't skip the coverage budget.** Only chasing failures leads to:
- Overfitting to edge cases
- Distorted class priors
- Models that fail on "normal" inputs

Coverage keeps the training set representative of the data the model will see in practice.
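One way to quantify "distorted class priors" is to compare the class distribution of labels from targeted samples against a reference, such as the golden set or the coverage samples. The sketch below uses total variation distance; that metric is our choice for illustration, not something this tutorial prescribes:

```python
from collections import Counter

def total_variation(counts_a, counts_b):
    """Total variation distance between two class-count Counters:
    0 = identical class proportions, 1 = no overlapping classes."""
    total_a, total_b = sum(counts_a.values()), sum(counts_b.values())
    if total_a == 0 or total_b == 0:
        return float("nan")  # undefined for an empty distribution
    classes = set(counts_a) | set(counts_b)
    return 0.5 * sum(abs(counts_a[c] / total_a - counts_b[c] / total_b) for c in classes)

# Hypothetical usage after annotation:
# targeted = Counter(d.label for s in dataset.match_tags("source:targeted")
#                    if s[LABEL_FIELD] for d in s[LABEL_FIELD].detections)
# print(total_variation(targeted, golden_stats["class_counts"]))

print(total_variation(Counter(person=50, car=50), Counter(person=75, car=25)))  # 0.25
```

A distance that grows from round to round suggests the targeted budget is pulling the label distribution away from the reference.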

In [None]:
# Progress summary
pool = dataset.load_saved_view("pool")
total_pool = len(pool)

annotated = len(dataset.match_tags("annotated:v0"))
selected_v1 = len(dataset.match_tags("batch:v1"))
still_unlabeled = len(pool.match(F("annotation_status") == "unlabeled"))

print("=" * 40)
print("ANNOTATION PROGRESS")
print("=" * 40)
print(f"Pool total:      {total_pool}")
print(f"Annotated (v0):  {annotated} ({100*annotated/total_pool:.0f}%)")
print(f"Selected (v1):   {selected_v1} ({100*selected_v1/total_pool:.0f}%)")
print(f"Still unlabeled: {still_unlabeled} ({100*still_unlabeled/total_pool:.0f}%)")
print("=" * 40)

## Summary

You implemented the iteration loop:
- **Golden QA check** to detect annotation drift
- **Hybrid acquisition**: 30% coverage + 70% targeted
- Tagged `batch:v1` ready for annotation

**Why this works:**
- Coverage prevents overfitting to edge cases
- Targeting fixes known failures
- Golden QA catches annotation drift early
- In practice, the combination tends to improve the model faster than either strategy alone

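**Your turn:** Repeat Steps 4-6 with batch_v1, then batch_v2, and so on, updating the version suffixes in tag and view names (`batch:v1`, `eval_v0_failures`, `annotated:v0`) each round.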
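The selection cells hard-code round-specific names. A small helper can generate them per round. This sketch assumes the pattern continues as in this notebook (selecting batch `k` uses failures from evaluating model `v{k-1}`, and earlier batches are tagged `annotated:v0` through `annotated:v{k-1}`); `round_names` is hypothetical:

```python
def round_names(k):
    """Tag/view names used when selecting batch k, assuming this notebook's pattern
    (batch:v1, batch_v1, eval_v0_failures, annotated:v0) continues each round."""
    if k < 1:
        raise ValueError("k must be >= 1 (v0 is assumed to be the initial batch)")
    return {
        "batch_tag": f"batch:v{k}",
        "batch_view": f"batch_v{k}",
        "failures_view": f"eval_v{k - 1}_failures",
        "annotated_tags": [f"annotated:v{i}" for i in range(k)],
    }

# Hypothetical usage for the next round:
# names = round_names(2)
# failures = dataset.load_saved_view(names["failures_view"])
# annotated = len(dataset.match_tags(names["annotated_tags"]))

print(round_names(2)["failures_view"])  # eval_v1_failures
```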