Fine-tune Google's PaliGemma vision-language model on the Flickr8k dataset for image captioning tasks.
PaliGemma is a powerful vision-language model that combines image understanding with text generation capabilities. This project demonstrates how to fine-tune the model on a custom dataset (Flickr8k) to improve its image captioning performance.
- Fine-tuning PaliGemma on Flickr8k dataset
- Custom training loop with progress tracking
- Model saving and loading utilities
- Batch inference for testing
- Visual result rendering
- JAX/Flax implementation with efficient sharding
pip install ml_collections
pip install tensorflow
pip install sentencepiece
pip install jax
pip install scikit-learn
pip install pandas
pip install numpy
pip install pillow
pip install kagglehub # For downloading model weights
Additional Requirements:
- Flickr8k dataset from Kaggle: https://www.kaggle.com/datasets/adityajn105/flickr8k
- Google's big_vision repository (automatically cloned)
- PaliGemma model weights (automatically downloaded via kagglehub)
- PaliGemma tokenizer (automatically downloaded from Google Cloud Storage)
├── finetune-paligemma/
│   └── finetune-paligemma.ipynb   # Training and testing script
├── PaliGemma_test_script.py       # Standalone test script
└── README.md
The project uses the Flickr8k dataset from Kaggle: https://www.kaggle.com/datasets/adityajn105/flickr8k
Dataset Details:
- 8,000 images with human-annotated captions
- Five captions per image
- Diverse scene descriptions covering people, animals, activities, and outdoor scenes
- High-quality annotations suitable for vision-language tasks
flickr8k/
├── captions.txt    # CSV file with image-caption pairs
└── Images/         # Directory containing 8,000 JPEG images
    ├── 1000268201_693b08cb0e.jpg
    ├── 1001773457_577c3a7d70.jpg
    └── ...
- Download the dataset from Kaggle: https://www.kaggle.com/datasets/adityajn105/flickr8k
- Extract to your desired location
- Update the `DATA_DIR` path in the scripts to point to your dataset location
The training script (`finetune-paligemma.ipynb`) includes:
- Model Configuration: PaliGemma 3B parameter model with specific image and language settings
- Fine-tuning Strategy: Only attention layers are trainable, keeping other parameters frozen
- Data Processing: Robust image preprocessing and tokenization
- Training Loop: 10 epochs with batch size of 10
- Progress Tracking: Real-time loss monitoring and time estimation
- Model Saving: Custom parameter saving with fallback mechanisms
| Parameter | Value |
|---|---|
| Learning Rate | 5e-5 |
| Batch Size | 10 |
| Epochs | 10 |
| Sequence Length | 128 tokens |
| Image Size | 224x224 pixels |
| Dataset Size | 1,500 samples (1,200 train / 300 validation) |
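As a quick sanity check on the configuration above, the train split and batch size imply roughly 120 optimizer steps per epoch. A minimal sketch (the variable names here are illustrative, not taken from the training script):

```python
# Training configuration from the table above (names are illustrative)
config = {
    "learning_rate": 5e-5,
    "batch_size": 10,
    "epochs": 10,
    "seqlen": 128,
    "image_size": 224,
    "train_samples": 1200,
    "val_samples": 300,
}

# 1,200 training samples / batch size 10 -> 120 steps per epoch
steps_per_epoch = config["train_samples"] // config["batch_size"]
total_steps = steps_per_epoch * config["epochs"]
```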
# Download Flickr8k dataset from Kaggle
# https://www.kaggle.com/datasets/adityajn105/flickr8k
# Update DATA_DIR path in the script to point to your dataset location
DATA_DIR = "/path/to/flickr8k"
# The training script automatically:
# 1. Downloads PaliGemma model weights via kagglehub
# 2. Downloads tokenizer from Google Cloud Storage
# 3. Loads and preprocesses Flickr8k dataset (1,500 samples)
# 4. Fine-tunes only attention layers
# 5. Saves the trained model to /kaggle/working/PaliGemma_Fine_Tune_Flickr8k/
# (update location if running locally)
# Alternatively, copy the code from finetune-paligemma.ipynb into a Kaggle
# notebook, add the Flickr8k dataset via the notebook's Add Input option,
# and run the cells.
The inference scripts (`finetune-paligemma.ipynb` and `PaliGemma_test_script.py`) provide:
- Load fine-tuned model weights
- Batch processing of test images
- Caption generation with customizable sampling
- Visual result rendering
- Robust error handling
# Ensure the fine-tuned model exists at the specified path
MODEL_PATH = "/kaggle/working/PaliGemma_Fine_Tune_Flickr8k/paligemma_flickr8k.params.f16.npz"
# Update the location to the actual path if running locally
# Update paths in the script if needed
DATA_DIR = "/path/to/flickr8k"
IMAGES_DIR = "/path/to/flickr8k/Images"
# Run inference
python PaliGemma_test_script.py
| Component | Specification |
|---|---|
| Base Model | PaliGemma 3B |
| Vision Encoder | SigLIP-So400m/14 |
| Language Model | Gemma-based decoder with a 257,152-token vocabulary |
| Fine-tuning | Attention layers only |
| Input Resolution | 224x224 |
| Max Sequence Length | 128 tokens |
- Data Loading: Flickr8k images and captions are loaded and preprocessed (1,000 images are used here due to system constraints; training on the full 8,000 with more epochs should yield better results)
- Model Setup: PaliGemma model is initialized with pretrained weights
- Parameter Masking: Only attention layers are marked as trainable
- Training Loop: Model is trained for 10 epochs with progress monitoring
- Validation: Model performance is evaluated on held-out validation set
- Model Saving: Trained parameters are saved for later inference
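The parameter-masking step above can be sketched as a walk over the parameter tree that marks only entries whose path contains an attention key. The tree keys below are illustrative; the real PaliGemma parameter tree follows big_vision's naming:

```python
# Sketch of selective parameter freezing: mark only parameters whose
# path mentions "attn" as trainable (True); everything else stays frozen.
def trainable_mask(params, path=""):
    if isinstance(params, dict):
        return {k: trainable_mask(v, f"{path}/{k}") for k, v in params.items()}
    return "attn" in path  # True -> gradient applied, False -> frozen

# Dummy parameter tree standing in for the real model (keys are made up)
params = {
    "img": {"attn": {"q": 0.0}, "mlp": {"w": 0.0}},
    "llm": {"layers": {"attn": {"kv": 0.0}}},
}
mask = trainable_mask(params)
# mask["img"]["attn"]["q"] -> True; mask["img"]["mlp"]["w"] -> False
```

In the actual JAX training loop, a boolean tree like this is typically combined with the gradients (e.g. via `jax.tree_util.tree_map`) so frozen parameters receive zero updates.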
- Training completed over 10 epochs
- Loss reduction observed during training
- Validation performed on 20% held-out data
| Image | Generated Caption |
|---|---|
| (image) | a biker is doing a wheelie on a yellow motorcycle |
| (image) | three women dressed in green for the parade |
| (image) | a brown and white dog standing in the water |
1. Image Preprocessing:
- Resize to 224x224
- Normalize to [-1, 1] range
- Handle grayscale to RGB conversion
2. Text Processing:
- SentencePiece tokenization
- Proper masking for training
- EOS token handling
3. Training Strategy:
- Selective parameter freezing
- Gradient accumulation
- Memory-efficient sharding
4. Inference Pipeline:
- Batch processing
- Greedy decoding
- Post-processing cleanup
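The image-preprocessing step above can be sketched as follows (a minimal Pillow/NumPy version; the real pipeline uses big_vision's preprocessing ops):

```python
import numpy as np
from PIL import Image

def preprocess_image(img: Image.Image, size: int = 224) -> np.ndarray:
    """Resize to size x size, force RGB, and scale pixels to [-1, 1]."""
    img = img.convert("RGB")                       # grayscale -> RGB
    img = img.resize((size, size), Image.BILINEAR)
    arr = np.asarray(img, dtype=np.float32)
    return arr / 127.5 - 1.0                       # [0, 255] -> [-1, 1]

# Example: a synthetic all-white grayscale image
dummy = Image.new("L", (640, 480), color=255)
out = preprocess_image(dummy)
# out has shape (224, 224, 3) with values at the top of the [-1, 1] range
```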
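Greedy decoding, as used in the inference pipeline above, simply picks the highest-probability token at every step until an end-of-sequence token appears. A toy sketch (`logits_fn` stands in for the model forward pass, which is an assumption here):

```python
import numpy as np

def greedy_decode(logits_fn, max_len=16, eos_id=1):
    """Greedy decoding: repeatedly take the argmax token until EOS."""
    tokens = []
    for _ in range(max_len):
        next_id = int(np.argmax(logits_fn(tokens)))
        if next_id == eos_id:
            break
        tokens.append(next_id)
    return tokens

# Toy "model" that deterministically emits token 5, then 7, then EOS (1)
script = [5, 7, 1]
out = greedy_decode(lambda toks: np.eye(10)[script[len(toks)]])
# out == [5, 7]
```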
- Memory Usage: Efficient parameter sharding for large model
- Training Speed: ~120 steps per epoch with batch size 10
- Inference Speed: Batch processing for efficient caption generation
- Model Size: Fine-tuned model maintains original parameter count
| Issue | Solution |
|---|---|
| Memory Errors | Reduce batch size or use gradient checkpointing |
| Loading Errors | Ensure proper model path and file permissions |
| Dataset Issues | Verify Flickr8k dataset structure and file paths |
| JAX Issues | Ensure proper JAX installation and device configuration |
Additional Solutions:
- Custom model saving/loading functions included for robustness
- Fallback mechanisms for parameter serialization
- Error handling for corrupted images
- Progress tracking for long training runs
- Experiment with different learning rates
- Add BLEU/ROUGE evaluation metrics
- Implement beam search decoding
- Fine-tune on additional datasets
- Add multi-GPU training support
- Implement gradual unfreezing strategy
- Google Research for PaliGemma model
- Big Vision team for the implementation framework
- Flickr8k dataset creators
- JAX/Flax development team