In [None]:
# Install required packages (run once). Use %pip rather than !pip so the
# packages are installed into this kernel's environment:
# %pip install groq datasets torch torchvision pillow


In [None]:
# Setup: Import libraries and set API key
import os
import sys
from pathlib import Path
import json

# Put the notebook's working directory first on the module search path so the
# local helper modules (groq_caption_generator, data_loaders, dataset_processors)
# can be imported below.
sys.path[0:0] = [str(Path.cwd())]

from groq_caption_generator import GroqCaptionGenerator
from data_loaders import ImageFolderWithMetadata, HuggingFaceDatasetWrapper, prepare_imagefolder_dataset
from dataset_processors import FFHQProcessor, EasyPortraitProcessor, LAIONFaceProcessor

# Read the Groq API key from the environment (never hard-code it in the notebook).
# A missing key is only reported with a warning; this cell does not stop execution.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")

print("✓ Groq API key loaded" if GROQ_API_KEY else "Warning: GROQ_API_KEY not set!")


## Step 1: Prepare Dataset Structure

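Choose your dataset and arrange it in the ImageFolder layout used below: a dataset root (e.g. `./data/ffhq_processed`) containing an `images/` directory. The `metadata.jsonl` caption file in that root is written in Step 2.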


In [None]:
# Example: Prepare FFHQ dataset
# Uncomment and modify paths as needed

# FFHQ_SOURCE = "./ffhq/images"
# FFHQ_OUTPUT = "./data/ffhq_processed"
# 
# processor = FFHQProcessor(FFHQ_SOURCE, FFHQ_OUTPUT)
# processor.prepare_structure()


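## Step 2: Generate Captions (Distributed Processing)

Each team member processes a different slice of the sorted image list by setting `START_INDEX` and `END_INDEX`. The slice is half-open: `START_INDEX` is included and `END_INDEX` is excluded (Python slicing), so adjacent ranges do not overlap.

**Distribution for 70k FFHQ images across 5 team members:**
- Member 1: `START_INDEX = 0`, `END_INDEX = 14_000`
- Member 2: `START_INDEX = 14_000`, `END_INDEX = 28_000`
- Member 3: `START_INDEX = 28_000`, `END_INDEX = 42_000`
- Member 4: `START_INDEX = 42_000`, `END_INDEX = 56_000`
- Member 5: `START_INDEX = 56_000`, `END_INDEX = 70_000` (or `None` for "to the end")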


In [None]:
# Configuration for your batch (see the Step 2 table for your slice).
# The slice is half-open: START_INDEX is included, END_INDEX is excluded.
DATASET_DIR = "./data/ffhq_processed/images"  # Path to images directory
OUTPUT_FILE = "./data/ffhq_processed/metadata.jsonl"  # Output metadata file
START_INDEX = 0  # starting index (inclusive)
END_INDEX = None  # ending index, exclusive (None = process all remaining)

# Initialize caption generator
generator = GroqCaptionGenerator(api_key=GROQ_API_KEY)

# Collect image files in sorted order so every team member slices the same list.
# Only regular files are kept: a sub-directory whose name happens to end in an
# image extension would otherwise be passed to the captioner.
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
dataset_path = Path(DATASET_DIR)
if not dataset_path.is_dir():
    raise FileNotFoundError(f"Image directory not found: {dataset_path.resolve()}")

image_paths = sorted(
    str(p) for p in dataset_path.iterdir()
    if p.is_file() and p.suffix.lower() in image_extensions
)

print(f"Total images: {len(image_paths)}")
# `is None` (not truthiness) so an explicit END_INDEX of 0 is reported as 0.
print(f"Processing indices: {START_INDEX} to {len(image_paths) if END_INDEX is None else END_INDEX}")


In [None]:
# Generate captions for this member's slice and append them to OUTPUT_FILE as
# JSON Lines (one {"file_name", "text"} object per line).
#
# The cell is safe to re-run: file names already recorded in OUTPUT_FILE are
# skipped. An interrupted run (or Restart & Run All) picks up where it stopped
# instead of appending duplicate entries.

# Resolve the slice end locally rather than overwriting the END_INDEX setting,
# so the configuration cell still reflects what the user chose.
end_index = len(image_paths) if END_INDEX is None else END_INDEX
batch_paths = image_paths[START_INDEX:end_index]

output_path = Path(OUTPUT_FILE)
output_path.parent.mkdir(parents=True, exist_ok=True)

# File names already captioned by an earlier run of this cell.
already_done = set()
if output_path.exists():
    with open(output_path, 'r', encoding='utf-8') as existing:
        for line_no, line in enumerate(existing, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{OUTPUT_FILE} line {line_no} is not valid JSON ({err}); "
                    "fix or remove the file before resuming"
                ) from err
            already_done.add(record.get("file_name"))

pending_paths = [p for p in batch_paths if Path(p).name not in already_done]
skipped = len(batch_paths) - len(pending_paths)

results = []
failed = []

print(f"Processing {len(pending_paths)} images ({skipped} already captioned)...\n")

# Open once in append mode and flush after every record. Progress is still
# saved incrementally, without reopening the file for each image.
with open(output_path, 'a', encoding='utf-8') as out_f:
    for idx, image_path in enumerate(pending_paths, start=1):
        image_name = Path(image_path).name
        print(f"[{idx}/{len(pending_paths)}] Processing: {image_name}")

        try:
            caption = generator.generate_caption(image_path)
        except Exception as e:
            # Best-effort: one bad image or transient API error must not abort
            # the whole batch. Failures are listed in the summary below.
            print(f"  ✗ Error: {e}\n")
            failed.append(image_name)
            continue

        if not caption:
            # NOTE(review): the generator's failure contract isn't visible here;
            # an empty/None caption is treated as a failure rather than saved.
            print("  ✗ Error: empty caption returned\n")
            failed.append(image_name)
            continue

        result = {"file_name": image_name, "text": caption}
        results.append(result)
        # A real newline terminates each record so the file stays valid JSONL.
        out_f.write(json.dumps(result, ensure_ascii=False) + '\n')
        out_f.flush()

        print(f"  ✓ {caption[:80]}...\n")

print(f"\n{'=' * 80}")
print(f"Completed: {len(results)}/{len(pending_paths)} captions generated "
      f"({skipped} already present, {len(failed)} failed)")
if failed:
    print(f"Failed files{' (first 20)' if len(failed) > 20 else ''}: {', '.join(failed[:20])}")
print(f"Results saved to: {OUTPUT_FILE}")
print('=' * 80)


In [None]:
# Step 3: load the generated metadata and spot-check a few captions next to
# their images.
from IPython.display import Image, display

NUM_EXAMPLES = 5  # number of entries to preview

# Load metadata: one JSON object per non-blank line.
with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
    metadata = [json.loads(line) for line in f if line.strip()]

print(f"Total entries in metadata: {len(metadata)}\n")

for i, entry in enumerate(metadata[:NUM_EXAMPLES], start=1):
    print('=' * 80)
    print(f"Example {i}:")
    print(f"File: {entry['file_name']}")
    print(f"Caption: {entry['text']}")
    print()

    # Show the image if it is present in the images directory; otherwise say so
    # explicitly instead of silently showing a caption without its image.
    image_path = Path(DATASET_DIR) / entry['file_name']
    if image_path.exists():
        display(Image(filename=str(image_path), width=200))
    else:
        print(f"(image not found: {image_path})")
    print()


## Step 4: Test Data Loader

Test that the dataset can be loaded correctly for training.


In [None]:
# Test ImageFolder dataset loader: build the dataset, inspect one sample, and
# pull one batch through a DataLoader to confirm samples collate for training.
from torch.utils.data import DataLoader

dataset_root = "./data/ffhq_processed"  # Path to your processed dataset
BATCH_SIZE = 4  # single source of truth for the batch size used and reported below

# Create dataset
dataset = ImageFolderWithMetadata(
    root_dir=dataset_root,
    metadata_file="metadata.jsonl",
    image_size=768,
    center_crop=True,
)

print(f"Dataset size: {len(dataset)}")
if len(dataset) == 0:
    raise ValueError(f"No samples found under {dataset_root}; generate metadata.jsonl first (Step 2)")

# Test loading a sample
sample = dataset[0]
print("\nSample data:")
print(f"  Image shape: {sample['pixel_values'].shape}")
print(f"  Text: {sample['text'][:100]}...")
print(f"  File name: {sample['file_name']}")

# Create dataloader and draw one batch (exercises collation, not just construction)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
first_batch = next(iter(dataloader))
print("\n✓ DataLoader created successfully")
print(f"  Batch size: {dataloader.batch_size}")
print(f"  Number of batches: {len(dataloader)}")
print(f"  First batch image shape: {first_batch['pixel_values'].shape}")


## Step 5: Merge Multiple Datasets (Optional)

If processing multiple datasets (FFHQ, EasyPortrait, LAION-Face), merge them into a single dataset.


In [None]:
# Merge multiple processed datasets into a single dataset (optional step).
from dataset_processors import merge_datasets

# Candidate processed dataset roots. This notebook itself prepares only FFHQ,
# so any directory that does not exist yet is skipped (and reported) rather
# than passed to merge_datasets.
candidate_dirs = [
    "./data/ffhq_processed",
    "./data/easyportrait_processed",
    "./data/laion_face_processed",
]
missing_dirs = [d for d in candidate_dirs if not Path(d).is_dir()]
if missing_dirs:
    print(f"Skipping missing dataset directories: {missing_dirs}")

dataset_dirs = [d for d in candidate_dirs if Path(d).is_dir()]
if not dataset_dirs:
    raise FileNotFoundError("None of the processed dataset directories exist; nothing to merge")

# Output directory for merged dataset
merged_output = "./data/merged_dataset"

# Merge datasets
merge_datasets(
    dataset_dirs=dataset_dirs,
    output_dir=merged_output,
    metadata_file="metadata.jsonl",
)
