This project extends the CoCoNuT (Chain of Continuous Thought) methodology to multimodal reasoning, combining InternVL3's vision-language capabilities with continuous latent reasoning for visual question answering tasks.
CoCoNuT replaces discrete textual reasoning steps with continuous thought vectors. Instead of writing each intermediate step out as text, the model feeds its last hidden state back in as the next input embedding, so reasoning happens in a high-dimensional latent space. This multimodal extension adapts the approach to image-text pairs, performing visual question answering through continuous latent reasoning.
- Continuous Thought Mechanism: Replaces discrete reasoning steps with continuous vector representations
- Staged Curriculum Learning: Progressive training from explicit reasoning to latent thoughts
- Multimodal Integration: Built on InternVL3-1B-Pretrained for vision-language understanding
- Distributed Training: Support for FSDP and DDP training across multiple GPUs
- Flexible Configuration: YAML-based configuration system following CoCoNuT patterns
- Python: 3.8 or higher (3.9 recommended)
- CUDA: Compatible GPU with CUDA 11.8+ (for GPU acceleration)
- Memory: At least 16GB RAM, 8GB+ GPU memory recommended
- Storage: 10GB+ free space for models and data
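Before installing, a quick pre-flight check of the requirements above can save time. This is a minimal standard-library sketch, not a script shipped with the project: the thresholds are copied from the list above, system RAM is not checked (that would need a third-party package such as psutil), and the GPU checks only run once PyTorch is installed.

```python
import shutil
import sys
from typing import List

MIN_PYTHON = (3, 8)      # "Python: 3.8 or higher"
MIN_FREE_DISK_GB = 10    # "Storage: 10GB+ free space"
MIN_GPU_MEMORY_GB = 8    # "8GB+ GPU memory recommended"


def preflight(path: str = ".") -> List[str]:
    """Return human-readable problems; an empty list means every check passed."""
    problems: List[str] = []
    if sys.version_info < MIN_PYTHON:
        found = sys.version.split()[0]
        problems.append(f"Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+ required, found {found}")

    # Free disk space where models and data will be stored
    free_gb = shutil.disk_usage(path).free / 1024**3
    if free_gb < MIN_FREE_DISK_GB:
        problems.append(f"only {free_gb:.1f} GB free at {path!r}; {MIN_FREE_DISK_GB} GB+ recommended")

    try:
        import torch  # optional: only present after the PyTorch install step
    except ImportError:
        problems.append("PyTorch is not installed yet; GPU checks skipped")
        return problems

    if not torch.cuda.is_available():
        problems.append("no CUDA device visible to PyTorch; training will be slow")
    else:
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        if total_gb < MIN_GPU_MEMORY_GB:
            problems.append(f"GPU 0 has {total_gb:.1f} GB; {MIN_GPU_MEMORY_GB} GB+ recommended")
    return problems


if __name__ == "__main__":
    for line in preflight() or ["All pre-flight checks passed."]:
        print(line)
```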
- Clone the repository:

```bash
git clone <repository-url>
cd multimodal-coconut
```

- Create and activate environment:

```bash
# Using conda (recommended)
conda create -n multimodal-coconut python=3.9 -y
conda activate multimodal-coconut

# Or using venv
python -m venv multimodal-coconut
source multimodal-coconut/bin/activate  # Linux/Mac
# multimodal-coconut\Scripts\activate   # Windows
```

- Install PyTorch (choose based on your CUDA version):

```bash
# CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# CUDA 12.1
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# CPU only (not recommended for training)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

- Install the package:

```bash
pip install -e .
```

Run the infrastructure test to verify everything is working:

```bash
python test_infrastructure.py
```

Expected output:

```
✓ All imports successful
✓ Configuration system working
✓ Model architecture functional
✓ Data pipeline operational
✓ Training utilities ready
✓ Infrastructure test passed!
```
If you encounter issues, see our Troubleshooting Guide for common solutions.
Common fixes:
- CUDA issues: Check the CUDA version reported by `nvidia-smi` and install the matching PyTorch build
- Memory errors: Ensure sufficient RAM/GPU memory
- Package conflicts: Use a fresh virtual environment
The project uses YAML configuration files following the original CoCoNuT patterns. Example configurations are provided in the args/ directory:
- `args/multimodal_cot.yaml`: Chain-of-Thought pre-training configuration
- `args/multimodal_coconut.yaml`: CoCoNuT training configuration
- Stage 0 - CoT Pre-training:

```bash
torchrun --nnodes 1 --nproc_per_node <N_GPUS> run.py args/multimodal_cot.yaml
```

- CoCoNuT Training:

```bash
torchrun --nnodes 1 --nproc_per_node <N_GPUS> run.py args/multimodal_coconut.yaml
```

- Evaluation:

```bash
# Set only_eval: true in your config file
torchrun --nnodes 1 --nproc_per_node <N_GPUS> run.py args/multimodal_coconut_eval.yaml
```

Project structure:

```
multimodal-coconut/
├── multimodal_coconut/       # Main package
│   ├── config/               # Configuration management
│   ├── data/                 # Data loading and processing
│   ├── model/                # Model implementations
│   ├── training/             # Training utilities
│   └── utils/                # Utility functions
├── args/                     # Configuration files
├── reference/                # Reference implementations
├── requirements.txt          # Dependencies
├── run.py                    # Main training script
└── test_infrastructure.py    # Infrastructure test script
```
The core innovation is replacing discrete textual reasoning steps with continuous vector representations:
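- Latent Tokens: Special `<|latent|>` tokens mark positions for continuous thoughts
- Hidden State Feedback: The last hidden state at the position preceding a latent token becomes that token's input embedding
- Iterative Processing: Because each continuous thought depends on the one before it, latent positions are filled by successive forward passes
- KV Cache Optimization: Key/Value tensors computed for earlier positions are reused across these passes instead of being recomputed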
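The hidden-state feedback loop can be illustrated with a toy, framework-free sketch. It is not the project's implementation: `toy_model` and `fill_latents` are hypothetical names, the "model" simply returns a running sum of input embeddings (which keeps it causal), embeddings are plain Python lists, and the KV cache is omitted, so each pass recomputes the prefix. What it does show is the dependency chain: each latent position receives the last hidden state of the prefix before it, so latent tokens must be resolved one pass at a time, left to right.

```python
from typing import Callable, Iterable, List

Vector = List[float]


def toy_model(embeddings: List[Vector]) -> List[Vector]:
    """Hypothetical causal 'model': hidden state i = element-wise sum of embeddings[0..i]."""
    hidden: List[Vector] = []
    running = [0.0] * len(embeddings[0])
    for emb in embeddings:
        running = [r + e for r, e in zip(running, emb)]
        hidden.append(running)
    return hidden


def fill_latents(
    embeddings: List[Vector],
    latent_positions: Iterable[int],
    model: Callable[[List[Vector]], List[Vector]] = toy_model,
) -> List[Vector]:
    """Replace the embedding at each latent position with the preceding hidden state."""
    embeds = [list(e) for e in embeddings]  # copy; leave the caller's list untouched
    for pos in sorted(latent_positions):
        if pos == 0:
            raise ValueError("a latent token needs at least one preceding position")
        # One forward pass per latent token: run the prefix (which may already
        # contain earlier continuous thoughts) and feed back its last hidden state.
        hidden = model(embeds[:pos])
        embeds[pos] = hidden[-1]
    return embeds


# Positions 2 and 3 are <|latent|> slots; their initial embeddings are ignored.
print(fill_latents([[1.0], [3.0], [0.0], [0.0]], latent_positions=[2, 3]))
# → [[1.0], [3.0], [4.0], [8.0]]
```

The second thought (8.0) is computed from an input that already contains the first (4.0). That dependency is why the passes run sequentially, and why reusing cached keys/values for the unchanged prefix saves work.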
Training progresses through stages:
- Stage 0: Standard multimodal chain-of-thought
- Stage k: The first k reasoning steps are replaced with `k * c_thought` latent tokens
- Progressive Deepening: k increases by one every `epochs_per_stage` epochs, up to `max_latent_stage`
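A minimal sketch of the curriculum, assuming the `c_thought`, `epochs_per_stage` and `max_latent_stage` parameters behave as their names in the configuration below suggest. The helper names `stage_for_epoch` and `build_stage_example` are hypothetical, examples are lists of text segments rather than token IDs, and capping the stage at `max_latent_stage` (and at the number of available steps) is an assumption about how the schedule saturates.

```python
from typing import List

LATENT = "<|latent|>"


def stage_for_epoch(epoch: int, epochs_per_stage: int, max_latent_stage: int) -> int:
    """Curriculum stage for a 0-based epoch: advance every `epochs_per_stage` epochs, then hold."""
    return min(epoch // epochs_per_stage, max_latent_stage)


def build_stage_example(question: str, steps: List[str], answer: str,
                        stage: int, c_thought: int) -> List[str]:
    """Replace the first `stage` reasoning steps with stage * c_thought latent tokens."""
    k = min(stage, len(steps))  # assumption: cannot replace more steps than the example has
    return [question] + [LATENT] * (k * c_thought) + steps[k:] + [answer]


steps = ["step 1", "step 2", "step 3"]
for epoch in (0, 5, 12, 39):
    stage = stage_for_epoch(epoch, epochs_per_stage=5, max_latent_stage=4)
    print(epoch, stage, build_stage_example("Q", steps, "A", stage, c_thought=2))
```

With `epochs_per_stage: 5` and `max_latent_stage: 4`, stage 4 is reached at epoch 20 (20 // 5 = 4) and, under this sketch's capping assumption, held for the remaining epochs of a 40-epoch run.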
Key configuration parameters:
```yaml
# CoCoNuT parameters
c_thought: 2              # Continuous thoughts per reasoning step
max_latent_stage: 4       # Maximum latent stage
epochs_per_stage: 5       # Epochs per curriculum stage

# Model settings
model_id: "OpenGVLab/InternVL3-1B-Pretrained"
image_size: 448
max_num_patches: 12

# Training settings
batch_size_training: 8
learning_rate: 1.0e-5     # decimal point included so PyYAML parses this as a float, not a string
num_epochs: 40
```

- API Documentation - Comprehensive API documentation for all classes and functions
- Troubleshooting Guide - Solutions for common issues and problems
- Configuration System: YAML-based configuration with validation and templates
- Model Architecture: MultimodalCoconut class extending InternVL3 with continuous thoughts
- Data Pipeline: Efficient multimodal data loading and preprocessing
- Training System: Staged curriculum learning with distributed training support
- Utilities: Logging, distributed training, and debugging tools
Training:

```python
from multimodal_coconut import load_config, MultimodalCoconut
from multimodal_coconut.training import MultimodalCoTTrainer

# Load configuration
config = load_config('args/multimodal_coconut.yaml')

# `tokenizer` and `image_processor` must be created beforehand, e.g.
# tokenizer = AutoTokenizer.from_pretrained(config.model_id, trust_remote_code=True)

# Initialize model and trainer
model = MultimodalCoconut.from_pretrained(config.model_id, config)
trainer = MultimodalCoTTrainer(model, tokenizer, image_processor, config)

# Start training
trainer.train()
```

Inference:

```python
import torch
from PIL import Image
from multimodal_coconut import MultimodalCoconut

# Load trained model
model = MultimodalCoconut.from_pretrained('path/to/checkpoint')
model.eval()

# Process image and question
image = Image.open('example.jpg').convert('RGB')
question = "What is happening in this image?"

# process_image, tokenize_text and decode_response stand for your own
# preprocessing and decoding helpers; they are not defined in this snippet.
with torch.no_grad():
    response = model.generate(
        pixel_values=process_image(image),
        input_ids=tokenize_text(question),
        max_new_tokens=100,
    )

print(f"Answer: {decode_response(response)}")
```

Custom configuration:

```python
from multimodal_coconut import create_config_from_template

# Create custom configuration
config = create_config_from_template(
    'coconut',
    c_thought=3,
    max_latent_stage=6,
    batch_size_training=4,
)

# Save for later use
config.save('args/my_experiment.yaml')
```

This project is currently in active development. The infrastructure and core components have been implemented and tested:
✅ Completed:
- Configuration system with validation
- Model architecture (MultimodalCoconut)
- Data pipeline components
- Training infrastructure
- Distributed training support
- Comprehensive documentation
- Testing framework
🚧 In Progress:
- Full integration testing
- Performance optimization
- Advanced evaluation metrics
📋 Planned:
- Additional dataset support
- Model compression techniques
- Deployment utilities
This project aims to keep the simplicity of the original CoCoNuT implementation. When contributing:
- Follow the existing code patterns and style
- Keep implementations minimal and focused
- Add comprehensive tests for new features
- Update documentation as needed
This project is licensed under the MIT License - see the LICENSE file for details.
- Original CoCoNuT paper and implementation
- InternVL3 team for the base multimodal model
- A-OKVQA dataset creators
If you use this work, please cite:
```bibtex
@article{coconut2024,
  title={Training Large Language Models to Reason in a Continuous Latent Space},
  author={...},
  journal={arXiv preprint arXiv:2412.06769},
  year={2024}
}
```